{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "f1297d2d-7c26-427b-b091-4b620421603b",
   "metadata": {},
   "source": [
    "# Libraries & Preparation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "4715fe95-1e4e-4a76-95a7-bc45508f4c9a",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "/home/user/miniconda3/envs/diploma/bin/python3\n"
     ]
    }
   ],
   "source": [
    "import sys\n",
    "print(sys.executable)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "32e7fe3a-ba3c-48ec-960f-2631b01392ec",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2024-06-14 23:27:13.133796: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n",
      "2024-06-14 23:27:13.133836: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\n",
      "2024-06-14 23:27:13.134829: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n",
      "2024-06-14 23:27:13.141129: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\n",
      "To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.\n",
      "2024-06-14 23:27:14.022794: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "\n",
    "from datetime import datetime\n",
    "\n",
    "# Visualization libraries\n",
    "import seaborn as sns\n",
    "import matplotlib\n",
    "import matplotlib.pyplot as plt\n",
    "import plotly.express as px\n",
    "\n",
    "# ML/DL libraries\n",
    "import tensorflow as tf\n",
    "from tensorflow.keras import layers\n",
    "import tensorflow.keras.applications as models\n",
    "import tensorflow_ranking as tfr\n",
    "import tensorflow_recommenders as tfrs\n",
    "from sklearn.preprocessing import LabelEncoder\n",
    "from sklearn.model_selection import train_test_split\n",
    "import keras_tuner as kt\n",
    "\n",
    "# Misc\n",
    "import sys\n",
    "import math\n",
    "import random\n",
    "from joblib import Parallel, delayed\n",
    "import PIL\n",
    "from PIL import Image\n",
    "import os\n",
    "import pathlib\n",
    "import warnings\n",
    "from pprint import pprint\n",
    "from numba import cuda \n",
    "from importlib import reload\n",
    "warnings.filterwarnings('ignore')\n",
    "warnings.simplefilter('ignore')\n",
    "\n",
    "# local\n",
    "import utils\n",
    "reload(utils)\n",
    "from utils import generate_dataset, \\\n",
    "        save_model_with_timestamp_and_lr, \\\n",
    "        predict_rank, compute_metrics, \\\n",
    "        compute_dataset_metrics, \\\n",
    "        plot_metrics, \\\n",
    "        create_dataset\n",
    "\n",
    "import metrics\n",
    "reload(metrics)\n",
    "from metrics import permutation_accuracy, \\\n",
    "        accuracy_by_rank, \\\n",
    "        accuracy_at_k, \\\n",
    "        dcg, \\\n",
    "        ndcg\n",
    "\n",
    "from metrics import NDCGMetric\n",
    "\n",
    "import constants\n",
    "reload(constants)\n",
    "from constants import *\n",
    "\n",
    "# models\n",
    "import models\n",
    "reload(models)\n",
    "from models import ResNet50RankingModel\n",
    "\n",
    "from models import EfficientNetB0RankingModel\n",
    "from models import EfficientNetB1RankingModel\n",
    "from models import EfficientNetB2RankingModel\n",
    "from models import EfficientNetB3RankingModel\n",
    "from models import EfficientNetB4RankingModel\n",
    "from models import EfficientNetB5RankingModel\n",
    "from models import EfficientNetB6RankingModel\n",
    "from models import EfficientNetB7RankingModel\n",
    "\n",
    "from models import MobileNetV2RankingModel\n",
    "\n",
    "from models import VGG16RankingModel\n",
    "from models import VGG19RankingModel\n",
    "\n",
    "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"1\"\n",
    "os.environ['TF_GPU_ALLOCATOR'] = 'cuda_malloc_async'\n",
    "\n",
    "gpus = tf.config.experimental.list_physical_devices('GPU')\n",
    "if gpus:\n",
    "    try:\n",
    "        for gpu in gpus:\n",
    "            tf.config.experimental.set_memory_growth(gpu, True)\n",
    "    except RuntimeError as e:\n",
    "        print(e)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3ebd2373-c4d9-4199-b21d-e216ab6bc711",
   "metadata": {},
   "source": [
    "## Reload"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 411,
   "id": "348d1be0-63ce-46b3-9182-98d5456b5025",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "273024"
      ]
     },
     "execution_count": 411,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import gc\n",
    "tf.keras.backend.clear_session()\n",
    "gc.collect()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5bc8f65f-baf3-4605-889f-d9cea83df91a",
   "metadata": {},
   "source": [
    "## Preparation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "7f639802-7f41-4df8-9e29-91bf2416762c",
   "metadata": {},
   "outputs": [],
   "source": [
    "np.random.seed(RANDOM_SEED)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "2650add3-6395-46c3-a2f1-1c3ee564e1fd",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Tensorflow Version: 2.15.1\n",
      "Num GPUs Available: \n",
      "[PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]\n",
      "Default GPU Device: /device:GPU:0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2024-06-14 23:27:15.774812: I tensorflow/core/common_runtime/gpu/gpu_process_state.cc:236] Using CUDA malloc Async allocator for GPU: 0\n",
      "2024-06-14 23:27:15.775224: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1929] Created device /device:GPU:0 with 31137 MB memory:  -> device: 0, name: Tesla V100-SXM2-32GB, pci bus id: 0000:08:00.0, compute capability: 7.0\n",
      "2024-06-14 23:27:15.777238: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1929] Created device /device:GPU:0 with 31137 MB memory:  -> device: 0, name: Tesla V100-SXM2-32GB, pci bus id: 0000:08:00.0, compute capability: 7.0\n"
     ]
    }
   ],
   "source": [
    "print(\"Tensorflow Version:\", tf.__version__)\n",
    "print(\"Num GPUs Available: \")\n",
    "pprint(tf.config.list_physical_devices('GPU'))\n",
    "\n",
    "if tf.test.gpu_device_name():\n",
    "    print('Default GPU Device: {}'.format(tf.test.gpu_device_name()))\n",
    "else:\n",
    "    print(\"Please install GPU version of TF\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "8fc8d454-ac3a-4fc2-bd0a-266a79867224",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2024-06-14 22:28:06.367591: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1929] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 31137 MB memory:  -> device: 0, name: Tesla V100-SXM2-32GB, pci bus id: 0000:08:00.0, compute capability: 7.0\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 10min 4s, sys: 29.3 s, total: 10min 34s\n",
      "Wall time: 10min 20s\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "\n",
    "train_dataset, test_dataset, val_dataset = create_dataset(DATASET_DIRECTORY)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "2f8858a4-2e73-47cc-b0b7-e8daaa6fa7bb",
   "metadata": {},
   "outputs": [],
   "source": [
    "with tf.device('/gpu:0'):\n",
    "    test_dataset = test_dataset.cache()\n",
    "    val_dataset = val_dataset.cache()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "b335266e-63dc-4976-93dc-9d0cd4d537a5",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train dataset size: 125\n",
      "Test dataset size: 27\n",
      "Validation dataset size: 28\n",
      "CPU times: user 32min 47s, sys: 2min 55s, total: 35min 43s\n",
      "Wall time: 35min 8s\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "\n",
    "print(\"Train dataset size:\", sum(1 for _ in train_dataset))\n",
    "print(\"Test dataset size:\", sum(1 for _ in test_dataset))\n",
    "print(\"Validation dataset size:\", sum(1 for _ in val_dataset))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0583c26b-5848-4d68-bf12-7cf039ce671e",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true
   },
   "source": [
    "# Training models & evaluation (without unfreezing)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3380174f-741d-40aa-8efb-e24c41309543",
   "metadata": {},
   "source": [
    "## ResNet50"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7aa8116c-a1c5-4167-b01e-76355aef0398",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true
   },
   "source": [
    "### Training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "cd730029-ca9f-4ebf-befa-59b44ca79aa3",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 1.92 s, sys: 180 ms, total: 2.1 s\n",
      "Wall time: 2.06 s\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "\n",
    "loss = tfr.keras.losses.ListMLELoss()\n",
    "model = ResNet50RankingModel(loss)\n",
    "lr = 1e-5\n",
    "optimizer = tf.keras.optimizers.Adam(learning_rate=lr, clipnorm=1.0)\n",
    "model.compile(optimizer=optimizer)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "0fe85640-24ee-4b53-98a7-ccd09b3ef5b5",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 3.15 ms, sys: 25 µs, total: 3.17 ms\n",
      "Wall time: 1.36 ms\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "\n",
    "batch = train_dataset.take(1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "a4786fdf-963d-472e-801d-72f8b213dd8f",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2024-06-11 05:54:46.680394: I external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:454] Loaded cuDNN version 8902\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/1 [==============================] - 12s 12s/step\n",
      "CPU times: user 15.3 s, sys: 1.68 s, total: 17 s\n",
      "Wall time: 12.5 s\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2024-06-11 05:54:52.559936: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 15404740284133898513\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "array([[-0.8131649 , -0.8312141 , -0.7823399 , -0.76662964, -0.83060336],\n",
       "       [-0.8723441 , -0.7641825 , -0.86517143, -0.6330479 , -0.5462649 ],\n",
       "       [-0.74472916, -0.9487216 , -0.7489414 , -0.8474132 , -0.79497707],\n",
       "       [-0.80689466, -0.7841099 , -0.7414913 , -0.8287448 , -0.7996489 ],\n",
       "       [-0.80470514, -0.75783455, -0.53529215, -0.7120016 , -0.85806936],\n",
       "       [-0.86850715, -0.81858444, -0.8176987 , -0.8241544 , -0.8195172 ],\n",
       "       [-0.8932461 , -0.81239706, -0.78819036, -0.70734936, -0.8799973 ],\n",
       "       [-0.7578581 , -0.6725249 , -0.8746257 , -0.790603  , -0.8086719 ],\n",
       "       [-0.9454566 , -0.8274246 , -0.9985315 , -0.8613885 , -0.67679614],\n",
       "       [-0.77525944, -0.7753032 , -0.81685084, -0.7747258 , -0.84702706],\n",
       "       [-0.68176544, -0.8439382 , -0.8045942 , -0.9340223 , -0.77534986],\n",
       "       [-0.8710236 , -1.0160403 , -0.8729487 , -0.73366237, -0.6880341 ],\n",
       "       [-0.8325136 , -0.8522108 , -0.7896058 , -0.82373226, -0.8154161 ],\n",
       "       [-0.762455  , -0.8594922 , -0.83490336, -0.6663109 , -0.8009621 ],\n",
       "       [-0.8132144 , -0.8150547 , -0.81568885, -0.8444403 , -0.84916687],\n",
       "       [-0.77514976, -0.81604296, -0.7742555 , -0.8355501 , -0.79522216],\n",
       "       [-0.76030123, -0.798618  , -0.7653465 , -0.8780537 , -0.8008087 ],\n",
       "       [-0.8483817 , -0.7970946 , -0.81604767, -0.740711  , -0.81399137],\n",
       "       [-0.8271253 , -1.0792475 , -0.92031074, -0.87774104, -0.8064679 ],\n",
       "       [-0.76187027, -0.8480998 , -0.80592334, -0.82507026, -0.8704524 ],\n",
       "       [-0.7873    , -0.7677481 , -0.8709984 , -0.87696433, -0.8423499 ],\n",
       "       [-0.810956  , -0.82103395, -0.8267464 , -0.8324322 , -0.81557256],\n",
       "       [-0.71777415, -0.6791558 , -0.6811663 , -0.74739754, -0.7531697 ],\n",
       "       [-0.8182075 , -0.791476  , -0.85866743, -0.86868805, -0.83058345],\n",
       "       [-0.8417996 , -0.8184751 , -0.8228353 , -0.7572077 , -0.86746275],\n",
       "       [-0.6565971 , -0.5594988 , -0.87580514, -0.799642  , -0.94449353],\n",
       "       [-0.82723045, -0.94959927, -0.86367714, -0.91183054, -0.92053896],\n",
       "       [-1.0084435 , -0.8963794 , -0.8424709 , -0.8243398 , -0.78985685],\n",
       "       [-0.6970963 , -0.6780878 , -0.74347055, -0.7533542 , -0.80340946],\n",
       "       [-0.81443405, -0.79679763, -0.7612849 , -0.92059875, -0.9075012 ]],\n",
       "      dtype=float32)"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "%%time\n",
    "\n",
    "model.predict(batch)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "386b98dd-481d-4b2c-a46e-0e5a148e37b2",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model: \"res_net50_ranking_model\"\n",
      "_________________________________________________________________\n",
      " Layer (type)                Output Shape              Param #   \n",
      "=================================================================\n",
      " resnet50 (Functional)       (None, 7, 7, 2048)        23587712  \n",
      "                                                                 \n",
      " flatten (Flatten)           multiple                  0         \n",
      "                                                                 \n",
      " sequential (Sequential)     (None, 64)                51553216  \n",
      "                                                                 \n",
      " sequential_1 (Sequential)   (None, 1)                 65        \n",
      "                                                                 \n",
      " ranking (Ranking)           multiple                  0 (unused)\n",
      "                                                                 \n",
      "=================================================================\n",
      "Total params: 75140993 (286.64 MB)\n",
      "Trainable params: 51553281 (196.66 MB)\n",
      "Non-trainable params: 23587712 (89.98 MB)\n",
      "_________________________________________________________________\n",
      "CPU times: user 20.3 ms, sys: 3.55 ms, total: 23.8 ms\n",
      "Wall time: 21.4 ms\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "\n",
    "model.summary()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "4e357e37-3bc9-45ee-8510-90fdbd61b62d",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1/10\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2024-06-11 05:55:06.264835: I external/local_xla/xla/service/service.cc:168] XLA service 0x7fa0cdbe5810 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:\n",
      "2024-06-11 05:55:06.264879: I external/local_xla/xla/service/service.cc:176]   StreamExecutor device (0): NVIDIA GeForce RTX 2080 Ti, Compute Capability 7.5\n",
      "2024-06-11 05:55:06.272021: I tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.cc:269] disabling MLIR crash reproducer, set env var `MLIR_CRASH_REPRODUCER_DIRECTORY` to enable.\n",
      "WARNING: All log messages before absl::InitializeLog() is called are written to STDERR\n",
      "I0000 00:00:1718085306.390981 4093947 device_compiler.h:186] Compiled cluster using XLA!  This line is logged at most once for the lifetime of the process.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "125/125 [==============================] - 609s 5s/step - ndcg_metric: 0.7932 - mrr_metric: 0.9475 - opa_metric: 0.6568 - loss: 4.3976 - regularization_loss: 0.0000e+00 - total_loss: 4.3976\n",
      "Epoch 2/10\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2024-06-11 06:05:01.804640: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 12090041538308137160\n",
      "2024-06-11 06:05:01.804692: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 336024244618928156\n",
      "2024-06-11 06:05:01.804710: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 13592514152826954238\n",
      "2024-06-11 06:05:01.804722: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 9150459179615854783\n",
      "2024-06-11 06:05:01.804736: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 14124643007598854667\n",
      "2024-06-11 06:05:01.804759: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 5986201689236825281\n",
      "2024-06-11 06:05:01.804766: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 12052128305733099475\n",
      "2024-06-11 06:05:01.804775: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 1191804465972497077\n",
      "2024-06-11 06:05:01.804782: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 17502760725129988841\n",
      "2024-06-11 06:05:01.804790: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 1165631904551291993\n",
      "2024-06-11 06:05:01.804803: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 5986304254054436011\n",
      "2024-06-11 06:05:01.804810: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 80806119737437509\n",
      "2024-06-11 06:05:01.804818: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 4014686933013434363\n",
      "2024-06-11 06:05:01.804825: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 7837847095135316601\n",
      "2024-06-11 06:05:01.804833: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 1255447990177992975\n",
      "2024-06-11 06:05:01.804839: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 10357641603270792359\n",
      "2024-06-11 06:05:01.804848: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 12102712943343597617\n",
      "2024-06-11 06:05:01.804855: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 15404740284133898513\n",
      "2024-06-11 06:05:01.804861: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 4214573948867264407\n",
      "2024-06-11 06:05:01.804878: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 8134687833692168376\n",
      "2024-06-11 06:05:01.804888: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 10209280755183404322\n",
      "2024-06-11 06:05:01.804896: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 3770310791444242484\n",
      "2024-06-11 06:05:01.804902: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 7455783319392503964\n",
      "2024-06-11 06:05:01.804910: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 1763042596492962274\n",
      "2024-06-11 06:05:01.804916: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 11851190110330389920\n",
      "2024-06-11 06:05:01.804923: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 10227197255373655302\n",
      "2024-06-11 06:05:01.804930: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 10660690237430099926\n",
      "2024-06-11 06:05:01.804940: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 12613588347451748438\n",
      "2024-06-11 06:05:01.804947: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 2274281497971218542\n",
      "2024-06-11 06:05:01.804955: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 2124805057338562956\n",
      "2024-06-11 06:05:01.804962: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 10614511526805838040\n",
      "2024-06-11 06:05:01.805028: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 15548532670003461848\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "125/125 [==============================] - 596s 5s/step - ndcg_metric: 0.8176 - mrr_metric: 0.9564 - opa_metric: 0.6921 - loss: 4.2175 - regularization_loss: 0.0000e+00 - total_loss: 4.2175\n",
      "Epoch 3/10\n",
      "125/125 [==============================] - 591s 5s/step - ndcg_metric: 0.8286 - mrr_metric: 0.9633 - opa_metric: 0.7076 - loss: 4.1226 - regularization_loss: 0.0000e+00 - total_loss: 4.1226\n",
      "Epoch 4/10\n",
      "125/125 [==============================] - 596s 5s/step - ndcg_metric: 0.8345 - mrr_metric: 0.9657 - opa_metric: 0.7154 - loss: 4.0497 - regularization_loss: 0.0000e+00 - total_loss: 4.0497\n",
      "Epoch 5/10\n",
      "125/125 [==============================] - 597s 5s/step - ndcg_metric: 0.8408 - mrr_metric: 0.9704 - opa_metric: 0.7236 - loss: 3.9890 - regularization_loss: 0.0000e+00 - total_loss: 3.9890\n",
      "Epoch 6/10\n",
      "125/125 [==============================] - 592s 5s/step - ndcg_metric: 0.8453 - mrr_metric: 0.9716 - opa_metric: 0.7294 - loss: 3.9338 - regularization_loss: 0.0000e+00 - total_loss: 3.9338\n",
      "Epoch 7/10\n",
      "125/125 [==============================] - 599s 5s/step - ndcg_metric: 0.8491 - mrr_metric: 0.9733 - opa_metric: 0.7360 - loss: 3.8820 - regularization_loss: 0.0000e+00 - total_loss: 3.8820\n",
      "Epoch 8/10\n",
      "125/125 [==============================] - 588s 5s/step - ndcg_metric: 0.8530 - mrr_metric: 0.9744 - opa_metric: 0.7406 - loss: 3.8314 - regularization_loss: 0.0000e+00 - total_loss: 3.8314\n",
      "Epoch 9/10\n",
      "125/125 [==============================] - 588s 5s/step - ndcg_metric: 0.8561 - mrr_metric: 0.9760 - opa_metric: 0.7458 - loss: 3.7867 - regularization_loss: 0.0000e+00 - total_loss: 3.7867\n",
      "Epoch 10/10\n",
      "125/125 [==============================] - 582s 5s/step - ndcg_metric: 0.8594 - mrr_metric: 0.9775 - opa_metric: 0.7507 - loss: 3.7437 - regularization_loss: 0.0000e+00 - total_loss: 3.7437\n",
      "CPU times: user 1h 31min 50s, sys: 8min 8s, total: 1h 39min 58s\n",
      "Wall time: 1h 39min 1s\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "\n",
    "history = model.fit(train_dataset, epochs=EPOCHS)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "37f49961-c41b-486b-b4a4-adaa4b27fe7f",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 23 µs, sys: 0 ns, total: 23 µs\n",
      "Wall time: 45.3 µs\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "\n",
    "show_on_test_batch = True"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "401de7a4-d28f-4968-a5ee-fcecc289428f",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<_TakeDataset element_spec=(TensorSpec(shape=(None, 5, 224, 224, 3), dtype=tf.float32, name=None), TensorSpec(shape=(None, 5), dtype=tf.int32, name=None))>\n"
     ]
    },
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
      "File \u001b[0;32m<timed exec>:4\u001b[0m\n",
      "File \u001b[0;32m~/.conda/envs/diploma/lib/python3.10/site-packages/tensorflow/python/data/ops/iterator_ops.py:810\u001b[0m, in \u001b[0;36mOwnedIterator.__next__\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m    808\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m__next__\u001b[39m(\u001b[38;5;28mself\u001b[39m):\n\u001b[1;32m    809\u001b[0m   \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m--> 810\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_next_internal\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    811\u001b[0m   \u001b[38;5;28;01mexcept\u001b[39;00m errors\u001b[38;5;241m.\u001b[39mOutOfRangeError:\n\u001b[1;32m    812\u001b[0m     \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mStopIteration\u001b[39;00m\n",
      "File \u001b[0;32m~/.conda/envs/diploma/lib/python3.10/site-packages/tensorflow/python/data/ops/iterator_ops.py:773\u001b[0m, in \u001b[0;36mOwnedIterator._next_internal\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m    770\u001b[0m \u001b[38;5;66;03m# TODO(b/77291417): This runs in sync mode as iterators use an error status\u001b[39;00m\n\u001b[1;32m    771\u001b[0m \u001b[38;5;66;03m# to communicate that there is no more data to iterate over.\u001b[39;00m\n\u001b[1;32m    772\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m context\u001b[38;5;241m.\u001b[39mexecution_mode(context\u001b[38;5;241m.\u001b[39mSYNC):\n\u001b[0;32m--> 773\u001b[0m   ret \u001b[38;5;241m=\u001b[39m \u001b[43mgen_dataset_ops\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43miterator_get_next\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m    774\u001b[0m \u001b[43m      \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_iterator_resource\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    775\u001b[0m \u001b[43m      \u001b[49m\u001b[43moutput_types\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_flat_output_types\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    776\u001b[0m \u001b[43m      \u001b[49m\u001b[43moutput_shapes\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_flat_output_shapes\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    778\u001b[0m   \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m    779\u001b[0m     \u001b[38;5;66;03m# Fast path for the case `self._structure` is not a nested structure.\u001b[39;00m\n\u001b[1;32m    780\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_element_spec\u001b[38;5;241m.\u001b[39m_from_compatible_tensor_list(ret)  \u001b[38;5;66;03m# pylint: disable=protected-access\u001b[39;00m\n",
      "File \u001b[0;32m~/.conda/envs/diploma/lib/python3.10/site-packages/tensorflow/python/ops/gen_dataset_ops.py:3024\u001b[0m, in \u001b[0;36miterator_get_next\u001b[0;34m(iterator, output_types, output_shapes, name)\u001b[0m\n\u001b[1;32m   3022\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m tld\u001b[38;5;241m.\u001b[39mis_eager:\n\u001b[1;32m   3023\u001b[0m   \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m-> 3024\u001b[0m     _result \u001b[38;5;241m=\u001b[39m \u001b[43mpywrap_tfe\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mTFE_Py_FastPathExecute\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m   3025\u001b[0m \u001b[43m      \u001b[49m\u001b[43m_ctx\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mIteratorGetNext\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mname\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43miterator\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43moutput_types\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43moutput_types\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   3026\u001b[0m \u001b[43m      \u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43moutput_shapes\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43moutput_shapes\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   3027\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m _result\n\u001b[1;32m   3028\u001b[0m   \u001b[38;5;28;01mexcept\u001b[39;00m _core\u001b[38;5;241m.\u001b[39m_NotOkStatusException \u001b[38;5;28;01mas\u001b[39;00m e:\n",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
     ]
    }
   ],
   "source": [
    "%%time\n",
    "\n",
    "if show_on_test_batch:\n",
    "    batch = train_dataset.take(1)\n",
    "    print(batch)\n",
    "    for images, labels in batch:\n",
    "        print(model(images)[0], labels[0], predict_rank(model, images)[0])\n",
    "        print(model(images)[1], labels[1], predict_rank(model, images)[1])\n",
    "        print(model(images)[2], labels[2], predict_rank(model, images)[2])\n",
    "        break"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "443e5e14-041e-43b7-929d-03393e6e6430",
   "metadata": {},
   "outputs": [],
   "source": [
    "%%time\n",
    "\n",
    "test_metrics = model.evaluate(test_images, test_labels, return_dict=True)\n",
    "print(f\"Test metrics: {test_metrics}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fa31d3aa-a8ac-4bef-b6cc-f7e0be6592d7",
   "metadata": {},
   "outputs": [],
   "source": [
    "%%time\n",
    "\n",
    "val_metrics = model.evaluate(val_images, val_labels, return_dict=True)\n",
    "print(f\"Val metrics: {test_metrics}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "a370d99e-c018-46cd-a8ca-1ca4e0488e42",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(100352, 512), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7fa4cc11caf0>, 140349681267088), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(100352, 512), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7fa4cc11caf0>, 140349681267088), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(512,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7fa4cc11f190>, 140349681413504), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(512,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7fa4cc11f190>, 140349681413504), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(512, 256), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7fa5b4ff3dc0>, 140349678970768), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(512, 256), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7fa5b4ff3dc0>, 140349678970768), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(256,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7fa51c129b70>, 140349681263888), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(256,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7fa51c129b70>, 140349681263888), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(256, 128), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7fa51c128130>, 140349678960064), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(256, 128), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7fa51c128130>, 140349678960064), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(128,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7fa51c12bd30>, 140349678959744), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(128,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7fa51c12bd30>, 140349678959744), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(128, 64), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7fa509d3d750>, 140349678975008), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(128, 64), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7fa509d3d750>, 140349678975008), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(64,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7fa509d3db70>, 140349678975488), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(64,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7fa509d3db70>, 140349678975488), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(64, 1), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7fa509d3e410>, 140349679027024), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(64, 1), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7fa509d3e410>, 140349679027024), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(1,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7fa568161c90>, 140349679027584), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(1,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7fa568161c90>, 140349679027584), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(100352, 512), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7fa4cc11caf0>, 140349681267088), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(100352, 512), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7fa4cc11caf0>, 140349681267088), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(512,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7fa4cc11f190>, 140349681413504), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(512,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7fa4cc11f190>, 140349681413504), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(512, 256), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7fa5b4ff3dc0>, 140349678970768), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(512, 256), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7fa5b4ff3dc0>, 140349678970768), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(256,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7fa51c129b70>, 140349681263888), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(256,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7fa51c129b70>, 140349681263888), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(256, 128), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7fa51c128130>, 140349678960064), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(256, 128), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7fa51c128130>, 140349678960064), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(128,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7fa51c12bd30>, 140349678959744), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(128,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7fa51c12bd30>, 140349678959744), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(128, 64), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7fa509d3d750>, 140349678975008), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(128, 64), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7fa509d3d750>, 140349678975008), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(64,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7fa509d3db70>, 140349678975488), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(64,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7fa509d3db70>, 140349678975488), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(64, 1), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7fa509d3e410>, 140349679027024), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(64, 1), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7fa509d3e410>, 140349679027024), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(1,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7fa568161c90>, 140349679027584), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(1,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7fa568161c90>, 140349679027584), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Assets written to: saved_models/ResNet50RankingModel_20240611_073401/assets\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Assets written to: saved_models/ResNet50RankingModel_20240611_073401/assets\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model saved in saved_models/ResNet50RankingModel_20240611_073401 as ResNet50RankingModel_20240611_073401\n",
      "CPU times: user 17.1 s, sys: 2.15 s, total: 19.3 s\n",
      "Wall time: 14.5 s\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "\n",
    "if SAVE_MODEL:\n",
    "    save_model_with_timestamp_and_lr(model, lr, freezed=\"freezed\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ec805a90-ebf3-4743-b036-bfcd5a0470da",
   "metadata": {},
   "source": [
    "### Evaluating"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 376,
   "id": "28a04c7d-d911-449e-900b-8f23e2393aef",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 4.75 s, sys: 542 ms, total: 5.29 s\n",
      "Wall time: 4.94 s\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "\n",
    "loaded_model = tf.keras.models.load_model('tesla_saved_models/ResNet50RankingModel_20240613_221115_freezed_0.001', \n",
    "                                          custom_objects={'ListMLELoss': tfr.keras.losses.ListMLELoss()})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 377,
   "id": "b7490405-eff6-4363-b12f-1a194a436d4c",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing batches: 0it [00:00, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/1 [==============================] - 1s 844ms/step\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing batches: 1it [00:04,  4.25s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/1 [==============================] - 0s 94ms/step\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing batches: 2it [00:07,  3.43s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/1 [==============================] - 0s 92ms/step\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing batches: 3it [00:10,  3.54s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/1 [==============================] - 0s 105ms/step\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing batches: 4it [00:14,  3.57s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/1 [==============================] - 0s 97ms/step\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing batches: 5it [00:17,  3.57s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/1 [==============================] - 0s 85ms/step\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing batches: 6it [00:21,  3.62s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/1 [==============================] - 0s 88ms/step\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing batches: 7it [00:25,  3.59s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/1 [==============================] - 0s 88ms/step\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing batches: 8it [00:28,  3.55s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/1 [==============================] - 0s 91ms/step\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing batches: 9it [00:31,  3.34s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/1 [==============================] - 0s 100ms/step\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing batches: 10it [00:34,  3.33s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/1 [==============================] - 0s 87ms/step\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing batches: 11it [00:37,  3.22s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/1 [==============================] - 0s 94ms/step\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing batches: 12it [00:40,  3.18s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/1 [==============================] - 0s 98ms/step\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing batches: 13it [00:44,  3.38s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/1 [==============================] - 0s 103ms/step\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing batches: 14it [00:48,  3.45s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/1 [==============================] - 0s 97ms/step\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing batches: 15it [00:51,  3.32s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/1 [==============================] - 0s 97ms/step\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing batches: 16it [00:54,  3.27s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/1 [==============================] - 0s 101ms/step\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing batches: 17it [00:58,  3.53s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/1 [==============================] - 0s 111ms/step\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing batches: 18it [01:02,  3.58s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/1 [==============================] - 0s 107ms/step\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing batches: 19it [01:06,  3.65s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/1 [==============================] - 0s 94ms/step\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing batches: 20it [01:09,  3.56s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/1 [==============================] - 0s 95ms/step\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing batches: 21it [01:12,  3.45s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/1 [==============================] - 0s 96ms/step\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing batches: 22it [01:16,  3.52s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/1 [==============================] - 0s 97ms/step\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing batches: 23it [01:19,  3.52s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/1 [==============================] - 0s 82ms/step\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing batches: 24it [01:23,  3.64s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/1 [==============================] - 0s 103ms/step\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing batches: 25it [01:27,  3.57s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/1 [==============================] - 0s 102ms/step\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing batches: 26it [01:31,  3.73s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/1 [==============================] - 0s 97ms/step\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing batches: 27it [01:34,  3.58s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/1 [==============================] - 0s 86ms/step\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing batches: 28it [01:38,  3.73s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/1 [==============================] - 0s 98ms/step\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing batches: 29it [01:42,  3.71s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/1 [==============================] - 0s 86ms/step\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing batches: 30it [01:45,  3.60s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/1 [==============================] - 0s 95ms/step\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing batches: 31it [01:49,  3.73s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/1 [==============================] - 0s 100ms/step\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing batches: 32it [01:52,  3.54s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/1 [==============================] - 0s 88ms/step\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing batches: 33it [01:56,  3.47s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/1 [==============================] - 0s 93ms/step\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing batches: 34it [01:59,  3.57s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/1 [==============================] - 0s 91ms/step\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing batches: 35it [02:03,  3.51s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/1 [==============================] - 0s 99ms/step\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing batches: 36it [02:06,  3.55s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/1 [==============================] - 0s 101ms/step\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing batches: 37it [02:09,  3.38s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/1 [==============================] - 0s 99ms/step\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing batches: 38it [02:13,  3.46s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/1 [==============================] - 0s 92ms/step\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing batches: 39it [02:17,  3.49s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/1 [==============================] - 0s 101ms/step\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing batches: 40it [02:21,  3.77s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/1 [==============================] - 0s 96ms/step\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing batches: 41it [02:25,  3.77s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/1 [==============================] - 0s 91ms/step\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing batches: 42it [02:28,  3.58s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/1 [==============================] - 0s 96ms/step\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing batches: 43it [02:31,  3.52s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/1 [==============================] - 0s 90ms/step\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing batches: 44it [02:35,  3.55s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/1 [==============================] - 0s 99ms/step\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing batches: 45it [02:39,  3.56s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/1 [==============================] - 0s 103ms/step\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing batches: 46it [02:42,  3.55s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/1 [==============================] - 0s 120ms/step\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing batches: 47it [02:45,  3.44s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/1 [==============================] - 0s 95ms/step\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing batches: 48it [02:49,  3.47s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/1 [==============================] - 0s 96ms/step\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing batches: 49it [02:53,  3.54s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/1 [==============================] - 0s 106ms/step\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing batches: 50it [02:56,  3.48s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/1 [==============================] - 0s 97ms/step\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing batches: 51it [02:59,  3.43s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/1 [==============================] - 0s 107ms/step\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing batches: 52it [03:03,  3.42s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/1 [==============================] - 0s 101ms/step\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing batches: 53it [03:06,  3.46s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/1 [==============================] - 0s 89ms/step\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing batches: 54it [03:09,  3.35s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/1 [==============================] - 0s 104ms/step\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing batches: 55it [03:13,  3.39s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/1 [==============================] - 0s 101ms/step\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing batches: 56it [03:16,  3.42s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/1 [==============================] - 0s 96ms/step\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing batches: 57it [03:20,  3.43s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/1 [==============================] - 0s 90ms/step\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing batches: 58it [03:23,  3.39s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/1 [==============================] - 0s 95ms/step\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing batches: 59it [03:27,  3.55s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/1 [==============================] - 0s 95ms/step\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing batches: 60it [03:30,  3.48s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/1 [==============================] - 0s 104ms/step\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing batches: 61it [03:33,  3.30s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/1 [==============================] - 0s 85ms/step\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing batches: 62it [03:37,  3.34s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/1 [==============================] - 0s 99ms/step\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing batches: 63it [03:40,  3.41s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/1 [==============================] - 0s 95ms/step\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing batches: 64it [03:44,  3.42s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/1 [==============================] - 0s 101ms/step\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing batches: 65it [03:47,  3.47s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/1 [==============================] - 0s 90ms/step\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing batches: 66it [03:51,  3.49s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/1 [==============================] - 0s 99ms/step\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing batches: 67it [03:54,  3.48s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/1 [==============================] - 0s 97ms/step\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing batches: 68it [03:57,  3.33s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/1 [==============================] - 0s 95ms/step\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing batches: 69it [04:01,  3.41s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/1 [==============================] - 0s 102ms/step\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing batches: 70it [04:04,  3.31s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/1 [==============================] - 0s 88ms/step\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing batches: 71it [04:07,  3.31s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/1 [==============================] - 0s 95ms/step\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing batches: 72it [04:10,  3.32s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/1 [==============================] - 0s 92ms/step\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing batches: 73it [04:13,  3.19s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/1 [==============================] - 0s 94ms/step\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing batches: 74it [04:17,  3.44s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/1 [==============================] - 0s 82ms/step\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing batches: 75it [04:21,  3.44s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/1 [==============================] - 0s 106ms/step\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing batches: 76it [04:24,  3.30s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/1 [==============================] - 0s 101ms/step\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing batches: 77it [04:27,  3.39s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/1 [==============================] - 0s 93ms/step\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing batches: 78it [04:31,  3.41s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/1 [==============================] - 0s 111ms/step\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing batches: 79it [04:35,  3.50s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/1 [==============================] - 0s 108ms/step\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing batches: 80it [04:38,  3.61s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/1 [==============================] - 0s 90ms/step\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing batches: 81it [04:42,  3.47s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/1 [==============================] - 0s 100ms/step\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing batches: 82it [04:45,  3.41s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/1 [==============================] - 0s 101ms/step\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing batches: 83it [04:48,  3.50s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/1 [==============================] - 0s 93ms/step\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing batches: 84it [04:52,  3.42s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/1 [==============================] - 0s 91ms/step\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing batches: 85it [04:55,  3.29s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/1 [==============================] - 0s 107ms/step\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing batches: 86it [04:58,  3.16s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/1 [==============================] - 0s 96ms/step\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing batches: 87it [05:01,  3.10s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/1 [==============================] - 0s 92ms/step\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing batches: 88it [05:04,  3.16s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/1 [==============================] - 0s 85ms/step\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing batches: 89it [05:07,  3.14s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/1 [==============================] - 0s 103ms/step\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing batches: 90it [05:11,  3.30s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/1 [==============================] - 0s 100ms/step\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing batches: 91it [05:14,  3.44s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/1 [==============================] - 0s 95ms/step\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing batches: 92it [05:18,  3.41s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/1 [==============================] - 0s 100ms/step\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing batches: 93it [05:21,  3.41s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/1 [==============================] - 0s 103ms/step\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing batches: 94it [05:25,  3.51s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/1 [==============================] - 0s 108ms/step\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing batches: 95it [05:28,  3.48s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/1 [==============================] - 0s 86ms/step\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing batches: 96it [05:32,  3.60s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/1 [==============================] - 0s 99ms/step\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing batches: 97it [05:36,  3.63s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/1 [==============================] - 0s 98ms/step\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing batches: 98it [05:40,  3.65s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/1 [==============================] - 0s 92ms/step\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing batches: 99it [05:43,  3.52s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/1 [==============================] - 0s 90ms/step\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing batches: 100it [05:46,  3.49s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/1 [==============================] - 0s 83ms/step\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing batches: 101it [05:50,  3.66s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/1 [==============================] - 0s 101ms/step\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing batches: 102it [05:54,  3.56s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/1 [==============================] - 0s 99ms/step\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing batches: 103it [05:57,  3.48s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/1 [==============================] - 0s 99ms/step\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing batches: 104it [06:01,  3.56s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/1 [==============================] - 0s 97ms/step\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing batches: 105it [06:04,  3.61s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/1 [==============================] - 0s 93ms/step\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing batches: 106it [06:08,  3.55s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/1 [==============================] - 0s 95ms/step\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing batches: 107it [06:11,  3.59s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/1 [==============================] - 0s 100ms/step\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing batches: 108it [06:15,  3.64s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/1 [==============================] - 0s 100ms/step\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing batches: 109it [06:19,  3.64s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/1 [==============================] - 0s 92ms/step\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing batches: 110it [06:22,  3.63s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/1 [==============================] - 0s 97ms/step\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing batches: 111it [06:26,  3.71s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/1 [==============================] - 0s 110ms/step\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing batches: 112it [06:30,  3.66s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/1 [==============================] - 0s 104ms/step\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing batches: 113it [06:33,  3.55s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/1 [==============================] - 0s 104ms/step\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing batches: 114it [06:37,  3.60s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/1 [==============================] - 0s 89ms/step\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing batches: 115it [06:40,  3.53s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/1 [==============================] - 0s 102ms/step\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing batches: 116it [06:44,  3.61s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/1 [==============================] - 0s 103ms/step\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing batches: 117it [06:47,  3.48s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/1 [==============================] - 0s 97ms/step\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing batches: 118it [06:50,  3.38s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/1 [==============================] - 0s 93ms/step\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing batches: 119it [06:54,  3.48s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/1 [==============================] - 0s 96ms/step\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing batches: 120it [06:58,  3.61s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/1 [==============================] - 0s 105ms/step\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing batches: 121it [07:01,  3.55s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/1 [==============================] - 0s 99ms/step\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing batches: 122it [07:05,  3.47s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/1 [==============================] - 0s 113ms/step\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing batches: 123it [07:08,  3.46s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/1 [==============================] - 0s 104ms/step\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing batches: 124it [07:12,  3.46s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/1 [==============================] - 0s 96ms/step\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing batches: 125it [07:18,  3.51s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 7min 40s, sys: 21 s, total: 8min 1s\n",
      "Wall time: 7min 18s\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "\n",
    "metrics = compute_dataset_metrics(loaded_model, train_dataset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 378,
   "id": "677a203b-5e6d-4792-accd-42859a482990",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'NDCG': 0.9158813069862075, 'permutation_accuracy': <tf.Tensor: shape=(), dtype=float32, numpy=0.43706658>, 'Kendalls Tau': 0.5082133333333333, 'Spearmans Rho': 0.5887733333333333, 'OPA': 0.7541066666666666, 'accuracy_by_rank': array([0.60320001, 0.41066667, 0.34026668, 0.33973334, 0.49146667]), 'accuracy@k': array([       nan, 0.49146667, 0.67013333, 0.80453333, 0.9008    ])}\n",
      "CPU times: user 1.04 ms, sys: 0 ns, total: 1.04 ms\n",
      "Wall time: 787 μs\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "\n",
    "print(metrics)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 379,
   "id": "262c85fa-38e5-4db0-9167-2710774f800b",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing batches: 0it [00:00, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/1 [==============================] - 0s 86ms/step\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing batches: 1it [00:00,  3.24it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/1 [==============================] - 0s 94ms/step\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing batches: 2it [00:00,  3.26it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/1 [==============================] - 0s 96ms/step\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing batches: 3it [00:00,  3.26it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/1 [==============================] - 0s 89ms/step\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing batches: 4it [00:01,  3.26it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/1 [==============================] - 0s 95ms/step\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing batches: 5it [00:01,  3.10it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/1 [==============================] - 0s 96ms/step\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing batches: 6it [00:01,  3.15it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/1 [==============================] - 0s 96ms/step\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing batches: 7it [00:02,  3.19it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/1 [==============================] - 0s 93ms/step\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing batches: 8it [00:02,  3.18it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/1 [==============================] - 0s 97ms/step\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing batches: 9it [00:02,  3.20it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/1 [==============================] - 0s 96ms/step\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing batches: 10it [00:03,  3.22it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/1 [==============================] - 0s 96ms/step\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing batches: 11it [00:03,  3.23it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/1 [==============================] - 0s 97ms/step\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing batches: 12it [00:03,  3.23it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/1 [==============================] - 0s 86ms/step\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing batches: 13it [00:04,  3.26it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/1 [==============================] - 0s 91ms/step\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing batches: 14it [00:04,  3.27it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/1 [==============================] - 0s 89ms/step\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing batches: 15it [00:04,  3.27it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/1 [==============================] - 0s 97ms/step\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing batches: 16it [00:04,  3.26it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/1 [==============================] - 0s 95ms/step\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing batches: 17it [00:05,  3.26it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/1 [==============================] - 0s 98ms/step\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing batches: 18it [00:05,  3.24it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/1 [==============================] - 0s 93ms/step\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing batches: 19it [00:05,  3.24it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/1 [==============================] - 0s 94ms/step\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing batches: 20it [00:06,  3.25it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/1 [==============================] - 0s 94ms/step\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing batches: 21it [00:06,  3.26it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/1 [==============================] - 0s 96ms/step\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing batches: 22it [00:06,  3.25it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/1 [==============================] - 0s 82ms/step\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing batches: 23it [00:07,  3.27it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/1 [==============================] - 0s 97ms/step\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing batches: 24it [00:07,  3.26it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/1 [==============================] - 0s 96ms/step\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing batches: 25it [00:07,  3.26it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/1 [==============================] - 0s 99ms/step\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing batches: 26it [00:08,  3.25it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/1 [==============================] - 0s 94ms/step\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing batches: 27it [00:08,  3.26it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/1 [==============================] - 0s 97ms/step\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing batches: 28it [00:08,  3.24it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 7.95 s, sys: 1.1 s, total: 9.05 s\n",
      "Wall time: 8.66 s\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "\n",
    "metrics = compute_dataset_metrics(loaded_model, val_dataset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 380,
   "id": "b6ebe212-3e6d-4722-868d-a46e03544986",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'NDCG': 0.8933689511603438, 'permutation_accuracy': <tf.Tensor: shape=(), dtype=float32, numpy=0.3797619>, 'Kendalls Tau': 0.40023809523809517, 'Spearmans Rho': 0.4689285714285714, 'OPA': 0.7001190476190476, 'accuracy_by_rank': array([0.54047619, 0.35714286, 0.30357144, 0.28928572, 0.40833334]), 'accuracy@k': array([       nan, 0.40833333, 0.61011905, 0.75952381, 0.88511905])}\n",
      "CPU times: user 555 μs, sys: 21 μs, total: 576 μs\n",
      "Wall time: 456 μs\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "\n",
    "print(metrics)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 381,
   "id": "cc80ed3c-2a6c-465c-9aa1-7b518573d4af",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing batches: 0it [00:00, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/1 [==============================] - 0s 88ms/step\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing batches: 1it [00:00,  3.27it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/1 [==============================] - 0s 88ms/step\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing batches: 2it [00:00,  3.26it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/1 [==============================] - 0s 96ms/step\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing batches: 3it [00:00,  3.25it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/1 [==============================] - 0s 92ms/step\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing batches: 4it [00:01,  3.27it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/1 [==============================] - 0s 90ms/step\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing batches: 5it [00:01,  3.27it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/1 [==============================] - 0s 89ms/step\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing batches: 6it [00:01,  3.26it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/1 [==============================] - 0s 95ms/step\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing batches: 7it [00:02,  3.26it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/1 [==============================] - 0s 97ms/step\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing batches: 8it [00:02,  3.25it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/1 [==============================] - 0s 98ms/step\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing batches: 9it [00:02,  3.23it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/1 [==============================] - 0s 96ms/step\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing batches: 10it [00:03,  3.24it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/1 [==============================] - 0s 97ms/step\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing batches: 11it [00:03,  3.24it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/1 [==============================] - 0s 97ms/step\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing batches: 12it [00:03,  3.24it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/1 [==============================] - 0s 98ms/step\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing batches: 13it [00:04,  3.23it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/1 [==============================] - 0s 97ms/step\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing batches: 14it [00:04,  3.23it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/1 [==============================] - 0s 97ms/step\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing batches: 15it [00:04,  3.24it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/1 [==============================] - 0s 95ms/step\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing batches: 16it [00:04,  3.24it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/1 [==============================] - 0s 94ms/step\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing batches: 17it [00:05,  3.25it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/1 [==============================] - 0s 97ms/step\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing batches: 18it [00:05,  3.25it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/1 [==============================] - 0s 96ms/step\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing batches: 19it [00:05,  3.24it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/1 [==============================] - 0s 97ms/step\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing batches: 20it [00:06,  3.24it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/1 [==============================] - 0s 94ms/step\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing batches: 21it [00:06,  3.25it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/1 [==============================] - 0s 98ms/step\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing batches: 22it [00:06,  3.24it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/1 [==============================] - 0s 105ms/step\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing batches: 23it [00:07,  3.22it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/1 [==============================] - 0s 97ms/step\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing batches: 24it [00:07,  3.23it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/1 [==============================] - 0s 97ms/step\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing batches: 25it [00:07,  3.23it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/1 [==============================] - 0s 90ms/step\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing batches: 26it [00:08,  3.24it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/1 [==============================] - 0s 97ms/step\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing batches: 27it [00:08,  3.24it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 7.65 s, sys: 1.09 s, total: 8.74 s\n",
      "Wall time: 8.33 s\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "\n",
    "metrics = compute_dataset_metrics(loaded_model, test_dataset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 382,
   "id": "0605f659-add9-4761-afea-28170deb3953",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'NDCG': 0.8922528057539286, 'permutation_accuracy': <tf.Tensor: shape=(), dtype=float32, numpy=0.37481487>, 'Kendalls Tau': 0.38419753086419745, 'Spearmans Rho': 0.45580246913580236, 'OPA': 0.6920987654320989, 'accuracy_by_rank': array([0.51728396, 0.36666667, 0.30617285, 0.28765433, 0.3962963 ]), 'accuracy@k': array([       nan, 0.3962963 , 0.60555556, 0.76213992, 0.87932099])}\n",
      "CPU times: user 633 μs, sys: 25 μs, total: 658 μs\n",
      "Wall time: 534 μs\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "\n",
    "print(metrics)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 383,
   "id": "71c69f7d-4780-46c8-8925-ca7a47a3c0d1",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 6 μs, sys: 1e+03 ns, total: 7 μs\n",
      "Wall time: 12.2 μs\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "\n",
    "show_on_test_batch = True"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 384,
   "id": "b2bd163d-99dc-41c7-abd7-1676a445f467",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<_TakeDataset element_spec=(TensorSpec(shape=(None, 5, 224, 224, 3), dtype=tf.float32, name=None), TensorSpec(shape=(None, 5), dtype=tf.int32, name=None))>\n",
      "tf.Tensor([829.56146 829.23553 829.4973  829.7194  828.1161 ], shape=(5,), dtype=float32) tf.Tensor([0 1 2 3 4], shape=(5,), dtype=int32) tf.Tensor([3 1 2 4 0], shape=(5,), dtype=int32)\n",
      "tf.Tensor([830.4785  830.2989  830.2312  829.06323 829.3869 ], shape=(5,), dtype=float32) tf.Tensor([4 3 1 0 2], shape=(5,), dtype=int32) tf.Tensor([4 3 2 0 1], shape=(5,), dtype=int32)\n",
      "tf.Tensor([829.67834 830.2779  829.4664  828.8362  829.6983 ], shape=(5,), dtype=float32) tf.Tensor([0 3 1 2 4], shape=(5,), dtype=int32) tf.Tensor([2 4 1 0 3], shape=(5,), dtype=int32)\n",
      "CPU times: user 814 ms, sys: 4.08 ms, total: 818 ms\n",
      "Wall time: 817 ms\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "\n",
    "if show_on_test_batch:\n",
    "    batch = test_dataset.take(1)\n",
    "    print(batch)\n",
    "    for images, labels in batch:\n",
    "        print(model(images)[0], labels[0], predict_rank(model, images)[0])\n",
    "        print(model(images)[1], labels[1], predict_rank(model, images)[1])\n",
    "        print(model(images)[2], labels[2], predict_rank(model, images)[2])\n",
    "        break"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 385,
   "id": "0887ab07-578c-4220-84f4-21f58c844e7a",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAABW0AAAI+CAYAAADHKihkAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/H5lhTAAAACXBIWXMAAA9hAAAPYQGoP6dpAABkW0lEQVR4nO3debyWY/4H8M9pO4VKpEhR2WIQkmQNkS2TLbJE1rFGY8Yue5axb2Gso2QZ69jGhGmMXbKNPWQrsrRS1Pn90c8ZZ85p49S50/v9ej2vca7nXr738zw13/Ppeq67pKysrCwAAAAAABRCrZouAAAAAACA/xLaAgAAAAAUSJ2aLgAAoCZMmDAhb7zxRj7++ONMmTKlpssB/kdpaWlatmyZVVddNQ0bNqzpcgAA5qsSa9oCAAubp556Kn9/+G/J9Klp0XzJ1K9fLyUlNV0V8KOysuS776bm0zFfJrXqZautt88GG2xQ02UBAMw3ZtoCAAuVF198MQ8/eG86rb1yNuq8bho0qF/TJQEz8e233+XJp1/Iww/em9LS0nTo0KGmSwIAmC+saQsALDTKysry1FNPZtUVWqTrZhsKbKHgGjSon66bbZhVV2iRp5/+d02XAwAw3whtAYCFxtixY/PFmE+zxuqrpMR6CLBAKCkpyRqrr5LPR3+SL774oqbLAQCYL4S2AMBC4+uvv07KpmXpZkvVdCnAXFi62VJJ2bR88803NV0KAMB8IbQFABYa06ZNS8qSunUt6w8Lkjp1aidlyQ8//FDTpQAAzBdCWwAAoNAsZwIALGyEtgAAlLt58F0pXaJdGi2zZj75dEyl57fsvnfW3qB7hbGV22+e0iXapXSJdqm/5Kpp1rpj1tmwew456uQ898LLMz3Xd99NySVX3piNuvbMUsuvm0bLrJnfdOyWvn88PW+/+36l7V99/a0ceNjxWXmtLdJomTWzRKt10nGTHjm+//kZ+cFHv+i6P/l0TPboc1Sate6Ypst1yM57HjrHx5w+fXquuWFIOm7SI0u0WietVtkw3Xc9ME8/O7zStsNHvJbtdzkgTZfrkCWXWyfb7rRfXn71jV9UOwAAvz6+GwgAQCVTpkzN+Zdck4vPPXmOtm+/xqo56rA+SZIJEyflzbffy133PpLrb74jRx6yT84/6/gK24/98ut03/WADB/xerbt1iW77bJ9Flt0kbz97vu5464Hc91Nd2TimFfLt7/upttzxDGnpemSTbL7LttnlZXa5odp0/L6G+9k0G335rKBN2fcpyNSu3btub7WiRMnZavf9s748RPzx34Hp26dOrnsqpuy5fZ757lhd2fJJZrMcv/jTjkvl1x5Y/bouUMO3q9Xvhk/IX++8bZ07d47Tzw0OB07rJkkeenl17PZtnum5bJL58Q/Hpbp08ty9fWD03X7vfPkP27PKiu1nevaAQD4dRLaAgBQSfs1Vs31N9+RPx51UFos03y227dYpln26LlDhbGz+x+T3gcdk0uvuikrrtA6B+/Xq/y5Aw47LiNeeSNDbrwkO+7QrcJ+p57QN6eceVH5z08/OzxHHHNaNui0du6+dWAaNlyswvbnnXFszrlg4M+5zCTJwOtvzbvvfZh//+OOrLvOGkmSbl03yTobds/FV9yQM07uN9N9f/jhh1xzw5DstEO33DDwvPLxnX+7ddqt3TW33nF/eWh76tmXpkH9+hn2yJDyIHiPnt2z+npb55QzLsptN1/2s68BAIBfF8sjAABQybFHH5xp06bn/Euu/dnHaNCgfm646tws0aRxzr1gYMrKypIkz73wch76+z+z7147Vwpsk6S0tF7OPePY8p/PPO+KlJSU5Mar/1QpsE2S+vVLc+qJfSvMsp08+du8+fbIjP3y69nWefd9j2TdddYoD2yTpN3KbbPZJuvnznsenuW+33//Q7799rs0a7ZkhfFmTZdIrVq10qBB/fKxfz/9QjbftHOFmbvLLN0sG2/QMQ/+/YlMnDhptrUCALBwENoCAFBJ6+WXzV67/TbX33xHPv2s8tq2c2qxxRbNb7fbMp98NiZvvPlukuRvDz+WJNlzt9/Odv/Jk7/NE/96NptsuF5aLrv0HJ/3+eGvpP362+aqa2+Z5XbTp0/Pq6+/lXXWWr3Scx3XWTMj3x+VCRMmznT/Bg3qZ70O7fOXW+/JrXfcn1Eff5pXX38rBxx2fJos3ij779OzfNspU6emQYPSSsdYpEGDTJ36fV5/4505vj4AAH7dhLYAAFTp2N//Lj/8MC1/uuTPv+g4q626UpLkvf+/sdebb49Mkqy+2sqz3fe990flhx9+yG/+/xg/9dXX32Tsl1+XP6ZOnTrXtX319bhMmTI1yzRfqtJzSy89Y+zT0Z/P8hg3XH1eVl6xdfY9+A9Zac3Ns+7Gv81Lr/wnjz80OG1btyrfbuUV2+TZF17OtGnTysemTp2a5198ZcZ5fkE4DgDAr4vQFgCAKrVt3Sp79twh1918ez6bTXA5K4stukiSlH/9f/z/z1xtuNiis933x21/PMZPtVt7yyy7Uufyx98eeqz8uU036pQpX72Zk487YpbH//bb75Ik9UrrVXqufumMWbHffTdllsdouNiiWa3dSvndAXvk9psvy6V/6p9pP0zLrnsdXmF5hoP33yPvvPtBDj7yxLzx5rt5/T9vZ79DjstnY76YUctszgMAwMJDaAsAwEwdd8wh+eGHaTn/4p+/tu3ESZOTzFgqIUka/f+6tBPmYA3XH4PdH4/xU3cOuiIP3nV9zjn9jz+7th/XnJ06pfIs3e+mzAhR69evvKTBj3744Ydss2OfNGq0WC4575T8dvstc/B+vfLg3ddn5Acf5cLLrivf9qA+u+fYfgdnyJ0PZK0Nts86G+2QkR+Myu+P3D9JsmgVwTQAAAsnoS0AADPVtnWr7LFr91802/Y//79W64ptlkuSrLJSmyTJa/95e7b7rtBmudSpU6fK9V432XC9bNFlg6yz1m9+Vl1JskSTxiktrVc+2/WnRo+eMdZi6WYz3f9fT72Q1994J9tvvXmF8ZVWaJ12K7fN088OrzB++klH56O3nsxjDw7Ki0/em6eG3pnp06cnSVZeofXPvg4AAH5dhLYAAMzSj7Ntf87athMnTsq9DzyaVssuk3arrJAk2a7bZkmSwbffN9v9F110kWyy4Xr511PP55NPq3/N11q1amX11VbO8BGvVXruuRdfTpvWrdLw/2cGV+XzL8YmSaZNm17pue+//yE//DCt0niTxRtnw/U7ZPXVVkmSPPbPp9OyxdJZZeW2P/cyAAD4lRHaAgAwSyu0WS577No9f77ptoz+fOwc7/ftt9+lzyHH5quvx+XY3x+ckpKSJMn6662drbbYODf85c7c+8A/Ku03derUHHvyueU/n/jHQzNt2rT0+d0fytfF/amysrJKY5Mnf5s33x5ZYU3Zmdlph255YfirefGlV8vH3npnZJ7417PZ+bdbV9j2zbdHZtTHn5b/vNL/z4694+4HKmz30suv5+1330/7NVed5bnvuOvBvDD81Rz+u96pVUtrDgDADHVqugAAAIrv2N//LoNuvy9vv/N+Vmu3UqXnP/3s8/KZsxMnTc4bb72bu+59JKPHfJGjDuuTA/fdvcL21191brbbef/s1vuIbLf1Ztlsk/Wz6CKL5N2RH+aOux7IZ2O+yLlnHJsk2ajzurn4vJNz9LFn5jcdt87uu2yfVVZqm6nff5933v0gQ+68P/Xq1U3z5kuVH//54a9kqx32yUl/PGy2NyM7eL9eue7mO9Jj99/lqMP2S926dXLplTemebMlc9RhfSps2379bbPJhh3z6P1/SZKss9bq2aLLBvnLrfdk/IRJ6brZhhk9+vNcee2gNGhQP0f8rnf5vv966vmcdf6V6brZhlmyyeJ57oWXc9Pgu7LVFhtX2A4AAIS2AADM1optl88ePbvnL7feU+XzL7/6Rvr87o8pKSlJw8UWTctll8m23bpkv713TccOa1bafqmmS+SfD9+agdcNzp33PJT+Z12cqVO/z3KtWmS7bTbP4QdXDDEP3q9X1u+4Vi696sYZYfDnX6Ru3bpp27pV9tq9Rw7ar1dW+P81c+dWw4aL5dH7bs4fThyQcy64KtPLpmeTDdfL+Wcdn6WaLjHb/f866MpcdPn1uf2uB/P3of9Kvbp1s2HnDjn1hL5ZZaX/LnnQYpnmqV2rVi667LpMmDgprZdrmdNO7Ju+h+6bOnW05QAA/FdJWVXfJwMA+BV64403cutfrsvvD987DRrUr+lygDn07bff5YLL/5Jee++fVVed9ZITAAC/BhbOAgAAAAAoEKEtAAAAAECBCG0BAAAAAApEaAsAAAAAUCBCWwAAAACAAhHaAgAAAAAUiNAWAIC59sGoj1O6RLvcPPiumi4FAAB+dYS2AABJbh58V0qXaFf+aLTMmvlNx27p+8fTM+bzsTVd3s/yxpvv5oxzLssHoz7+2ccYcuf9ufSqm6qxquq1R5+jUrpEu5xw6p9qupRfhfsfeiyduuyURsusmRXX2CynD7g0P/zww2z3+/SzMdn34D9k9fW2zpLLrZNmrTtmw6675i+33p2ysrIK267cfvMKf9Z++lht3W7z6tIAABYodWq6AACAIul//JFpvXzLfPfdlDz17Iu55vohefjRYXnp3/dnkUUa1HR5c+WNt97NmeddkU02Wi+tl2v5s44x5M6/5T9vvJMjD9mnwvjyrZbNuE9fTt26NddOjh8/MQ888niWX27Z3P7XB3JW/9+npKSkxupZ0D386LDsutdh2XSj9XLROSfltTfezoALBubzsV/l8gtOneW+X371TT75dHR22qFbWrVcJt9//0OGPvFUDjjs+Lz97vs54+R+5duef/YJmTRpcoX9R330afqfdXG6brbhvLg0AIAFjtAWAOAnunXdOB3WXiNJsl/vXbNEk8VzyZU35v6Hhma3nbf/RceePPnbBS74nZmSkpLUr19aozXcff8jmTZteq657Kx0++2++ddTz2eTDder0ZqqUlZWlu++m5IGDerXdCmzdNwp52WN36ySB/56XerUmfFrQqOGi+XcC6/O4Qf3TruV28503zV+s0oevf8vFcYOPXCv7Njrd7nimlty6gl9U7t27STJb7frWmn/AX+6KknSa5df9mcMAODXwvIIAACz0GWT9ZMkH3z43yUGBt9+X9bfbKc0btE+S7ftlL3275ePPv6swn5bdt87a2/QPcNHvJYtttsriy+7Vk4+46LytWAvvOy6XPXnQVll7a5ZfNm1su1O++Wjjz9LWVlZzj7/yrT9zaZp3KJ9dt7z0Hz19TcVjl26RLuccc5llWpduf3mOeCw45LMWO6hV5+jkiRb7bBP+dfP//nks0mS+x4cmt/udnBar7ZxGi69Rtqts2XOPv/KTJs2rcI1PPT3f+bDjz4t33/l9pvPeD1msqbt48Oeyebb7pkmLddOs9Yds/Oeh+aNt96rsM0Z51yW0iXa5d2RH+aAw45Ls9Yds9Ty6+bAw47P5Mnfzulbk1vv/Fu26LJBumy8ftqtvEJuveP+Krd78+2R2aPPUVl2pc5p3KJ9Vl9v65xy5kUVtvnk0zE5+IgTy1+PldfaIof//tRMnTq1Qs3/68dlNX66BMXK7TdPj90Pzt+H/iudN985jVu0z7U33pYkuWnQX9Ptt/uk5cobpOHSa6T9+tvl6utvrbLuhx8dlq7b75Ull1snTZfrkA222CVD7pxxjacPuDSLNls9X4z9qtJ+hxx1cpq17pjvvpuSz0Z/njffHpnvv/9+lq/lG2++mzfeejf779OzPLBNkoP365WysrLcfd8js9x/ZpZfbtlMnvxtpk6d9fmH3Pm3tF6+ZTp3WudnnQcA4NfGTFsAgFkY+f6oJMkSSyyeJDnngoE59exLskuPbdJn710zduxXufLaW7LF9nvluX/encUbNyrf96uvv8kOPQ/Krjttm167dk+zZk3Lnxty598yder3OfTAvfL11+NywWV/zp77HZUum6yfYU8+l9/3PSDvjRyVK6+9JcedfF6uufzsuap7ow065rCD9s4V1/wlx/Y7OO1WXiFJyv/3L7fencUWXSRHHrJvFlts0Twx7JmcNuDSjJ8wMeec/sckybH9fpdx4yfmk09H5/yzjk+SLLroIjM959AnnsoOPQ9Km9Ytc9IfD893332XK6+9JZtts0eeeeKvlZZo2HO/o9N6uWVzxin9MuLl/+T6v9yRpZZaMmefesxsr+/Tz8bkn/96NtddeU6SZLedt8ulV92US847OfXq1Svf7tXX38rm2+6ZunXrZP99emb5Vstm5Acf5YGHH8/pJx1dfqyNttw134ybkP1798wqK7fJp59+nrvueySTv/2uwvHm1NvvfpDeBx6TA/btmf1675qVV2yTJLnmhiFZbZUVs93Wm6dOndp54OHHc+Qxp2X69Ok55IA9y/e/efBdOeiIE7NauxXzx6MOSuPGjfLyK2/k70OfzO67dM8eu/02Z51/Ze64+8EceuBe5ftNnTo1d9/39+zYfavUr1+ak8+4MH+59Z68NeIfs1wiY8SrbyRJOqy1eoXxFss0T8sWS2fEK/+Zo+v+9tvvMmnyt5k4aVL+9e/nc/Pgu7N+x7VmOct4xCv/yZtvv5fjfv+7OToHAMDCQGgL8Ct24403pk+fPnn++eez7rrr1nQ5sEAYN35ixn75db77bkqefnZ4zj7/yjRoUD/bbrVZPvzok5x+zmU57cS+ObbffwOmHttvmfW67JSrrxtcYXz0mC9y+YWn5sB9dy8f+3FG5qefjcnrLzySxo0aJkmmTZ+W8y66Jt9+NyVPP3Zn+WzHsV9+lVvvvD+XXXBqSkvnPDxs27pVNurcIVdc85ds0WWDbLpRpwrP33zNnyoEaQf12T2H9eufq6+/NaedeFRKS+ul62Yb5vKrb84334zLHj13mO05j+9/fpZo0jjDHhmSJZosniTZYbuuWW/THXPGOZfluivPrbD9WmusmqsvO6v85y+//iY33nLnHIW2t/31gZSW1kv3bWbM/N11p21z2oBL89Cjwyp8/f7oY89IWVlZnnnirizXskX5+Fn9f1/+3yefcWFGjxmbJx+9rXxpjCTpf8KRlW6iNafeG/lh7r/j2my1xcYVxv9x/18qvO6HHrhXtt/lgFxy5Y3loe248RPS77iz0nGdNfPo/TdXWIbix3pWbLt81u+4Vm69/f4Koe1Df/9nvv5mXPbYbfbv1099NvrzJMnSzZeq9NzSSy9V/vzsXH71zTnp9AvLf95s08659rJZ/4PDjzOkd9+l+5yWC1AYJSUlOeyww3L55ZfXdCnAr4zlEYAFwpVXXpmSkpJ06tRp9hsz3+27775ZbLHFaroMqBbb7Ngny67UOSus0SV7HdAviy26SG6/+bIs26J57rn/0UyfPj0799gmY7/8uvzRvPlSWbHt8nniX89VOFZpab3ss8dOVZ5np99uXR7YJknHDu2TJL127V7h6+kdO7TP1Knf55PPxlTrdf40OJwwYUZQvVHndTN58rd5652Rc328z0Z/npdffSN799qxPLBNZqx1ukWXDfLwo8Mq7XNgn90r/Lzh+h3y5VffZPz4ibM935A7/5Ztttw0DRvO+LtnpRVaZ521fpMhP1ki4YuxX+VfT72QffbcuUJgm6T8hmXTp0/PfQ8MzXZbb1YhsP3f7eZW6+VbVgpsk4qv+7jxEzL2y6+zyYYd8/4HH2Xc+AlJkqGP/zsTJk7KMUcdWGnd4J/Ws+fuPfLciy/nvf+fDZ7MCEBbLbtM+dq+f77inEz56s3Z3ojuu++mJEmV/zBQv7Q03/7/87PTc+ft8uBd1+fma/+U3f9/fdpvv/tupttPnz49d9z1YNZac7WsusoKc3QOYM7poX+esrKyTJw4+/8vApiXzLQFFgiDBg1K69at89xzz+Xdd9/NiiuuWNMlAb9Sl5x/SlZaoXXq1KmdZks1zSortUmtWjP+nfvdkR+mrKwsv1m3W5X71q1bsbVqsUzzmX61frmWy1T4uXGjGeFjq2WrHv/mm3FJWs319czMf954J/3PviRPDHsm4ydU/MX0x/Bwboz66NMkKV8G4KfarbxCHn3syUyaNLnC8gqt/uc1aLL4jKUlvh43Lo0azfwfgt54672MeOU/2XO33+bdkR+Wj2+y4XoZeN3gjB8/MY0aLZb3P/goSfKbVVea6bG+GPtVxk+YOMttfo7Wy1cdkj71zPCcfu5lefb5EZXW7x03fkIaN2qYkXNQd5LsuuM2OeaEszPkjvtz4h8Py7jxE/Lg35/IkYfsM9dh84/h8JQpUys9992UKWkwhzedW77Vslm+1bJJkt123j6HHHVytt1xv7z63ENVLpEw7N/P5ZPPxuSIQ/aZq3qBOaOHnnOTJ0/OVVddldtuuy0jRozI999/n0UWWSQdO3ZMnz59svfee5f3AwDzg9AWKLz3338/Tz31VO66664cfPDBGTRoUPr371/TZVVp0qRJWXTRRWu6DOAX6LjOGlXOuExmzAosKSnJfbdfk9q1a1d6frH/We+1Qf2Zr+M5s1/8ateuenxOvqY/bdr02W6TJN+MG5+u3fdOw4aL5ZTjj0zbNq1Sv7Q0L73yn5x46p8yffqcHeeXquo1TGZ/rbfecV+S5A8nDsgfThxQ6fm7738k++y58y8v8CdmFoLO7DWvKuR87/1R2XrHfbPKSm1z3pnHpuWyy6Re3bp5+NF/5tKrbprr173J4o2z7VZdcuudM0Lbu+59JFOmTM0eu87d0ghJsszSzZLMWNLjf8P00aO/yLrrVP1nYnZ22qFbrr/5jvzrqeernHk85I6/pVatWtlt5+1+1vGBmVtQeugi9M8vvPBCdtxxx0yePDm77757+vbtmyWWWCKff/55Hn/88Rx22GEZOHBg/vrXv6ZFixazPyBANfDPREDhDRo0KE2aNMl2222XXXbZJYMGDapyu2+++SZHH310WrdundLS0rRs2TK9e/fO2LFjy7f57rvvcuqpp2bllVdO/fr1s8wyy2SnnXbKe+/NuLP5E088kZKSkjzxxBMVjv3BBx+kpKQkN954Y/nYj0sCvPfee9l2223TsGHD7LnnjPUI//Wvf2XXXXfNcsstl9LS0rRq1SpHH310vv228l3R33zzzfTs2TNLLbVUGjRokFVWWSUnnnhikuTxxx9PSUlJ7r777kr7DR48OCUlJXn66adn+xpOnjw5Bx98cJZccsk0atQovXv3ztdff13+/D777JOmTZtWeXfxrbbaKqussspsz/G/Wrdune233z5PPPFE1l133TRo0CBrrLFG+Wt71113ZY011kj9+vXToUOHvPTSSxX2f+WVV7Lvvvumbdu2qV+/fpZeeunst99++fLLLyud68dz1K9fPyussEKuvvrqnHrqqVWGLLfccks6dOiQBg0aZIkllsjuu++ejz76aK6vj4VT29bLpaysLK2Xb5ktumxQ6dGp41rzpY4mizfON+MqzoadOnVqPhvzRYWxmQWN/3zyuXz51Tf58xUDcsTveme7bptliy4bpMlPbqI2u2P8r+Vazfgl9u1336/03FvvjEzTJZvM8iZmc6qsrCxD7vxbumzcKbfecHGlxxq/WSW33vm3JEmb1jNmJr/+xjszPd5STZdIo4aLzXKbJFn8/2cBfzNufIXxH2cYz4kHHn48U6ZMzV8HX5kD990922y5abboskGlGaht56DuH+25e4+88+4HeWH4qxly5/1Za83VstrPmDXcfvV2SZIXR7xWYfzTz8bk409Hp/0aq871MZOUL6tQ1ZIXU6ZMzd33/z2bbLReWizT/GcdH5i5Oemh9c/Jyy+/nC5dumSjjTbKyJEjc8UVV2TPPffMNttsk3322Sc33nhj3nzzzSy66KLp2rVrhR56Zs4888zUqlUrl1122Wy3BZgZoS1QeIMGDcpOO+2UevXqpVevXnnnnXfy/PPPV9hm4sSJ2XjjjXPZZZdlq622yiWXXJLf/e53efPNN/PxxzNu+jNt2rRsv/32Oe2009KhQ4dccMEF6du3b8aNG5fXXnutqlPP1g8//JBu3bqlWbNm+dOf/pSdd54xs+uOO+7I5MmTc8ghh+Syyy5Lt27dctlll6V3794V9n/llVfSqVOnPPbYYznwwANzySWXpEePHrn//hlrMnbp0iWtWrWqsskeNGhQVlhhhXTu3Hm2dR5++OF54403cuqpp6Z3794ZNGhQevToUT6bbe+9986XX36ZRx55pMJ+o0ePzmOPPZa99tqrqsPO1rvvvps99tgj3bt3z4ABA/L111+ne/fuGTRoUI4++ujstddeOe200/Lee++lZ8+eFWaZPfrooxk5cmT69OmTyy67LLvvvnuGDBmSbbfdtsIsvJdeeilbb711vvzyy5x22mnZf//9c/rpp+eee+6pVM9ZZ52V3r17Z6WVVsqFF16Yo446KkOHDs0mm2ySb7755mddIwuXHt23TO3atXPWeVdUmg1aVlaWL7+a/S9y1aFtm1Z58ukXKoz9+abbM23atApjiywyIyT934D3x9m8P72GqVOn5urrb610rkUXWSTj5mCN2WWWbpb2a6yaW4bcUyHYfP0/b+cfj/87W2+5yWyPMSeeenZ4Phz1SXrvsVN2+u3WlR677rhN/vmvZ/PpZ2OyVNMlsvEG6+amQX/NqI8rhqs/XnutWrWyw3Zb5IGHH8+LL71a6Xw/bte2zXJJkn899d/XfdKkybllyD1zXHtVr/u48RNy8+C7KmzXdbON0nCxRXP+RdeUrzX7v/X8aOuuG6fpkk3yp0uuzbB/P59eu1a8mddnoz/Pm2+PrPIf5X5qtVVXyiortc11//M5uub6ISkpKcmOO/x3SZBx4yfkzbdHVlhG44uxX1V53BtvuTMlJSVZq/1qlZ57+NF/5ptx49PLDchgnphdD61/nlHLbrvtll133TWDBw9O48aNk8wIqn/8e3Py5MlZfPHF88ADD6Rp06Y54YQTZnl9J510Uk455ZRcffXVOeKII37WawSQWB4BKLgXX3wxb775Zvm/Um+00UZp2bJlBg0alI4dO5Zvd/755+e1117LXXfdlR133LF8/KSTTir/Bffmm2/O0KFDc+GFF+boo48u3+a444772XcHnzJlSnbdddcMGFDx67nnnntuGjRoUP7zQQcdlBVXXDEnnHBCRo0aleWWm/HL/xFHHJGysrIMHz68fCxJzjnnnCQzZrjttddeufDCCzNu3LjyRvKLL77I3//+9/IZBbNTr169DB06NHXr1k2SLL/88vnjH/+Y+++/PzvssEM233zztGzZMrfccku233778v1uvfXWTJ8+/WeHtm+99Vaeeuqp8sZ4tdVWS7du3XLggQfmzTffLL/mJk2a5OCDD86wYcPSpUuXJMmhhx6a3//+9xWOt/7666dXr1558skns/HGM75m279//9SuXTv//ve/y7+u1rNnz6y6asVZYR9++GH69++fM888s0KzvdNOO2XttdfOlVdeOdsmHFZos1xOO7FvTjr9wnw46pN037ZrGjZcNB98+HHufeDR7N+7Z/odsf88r6PP3rvk8H6nZrfeR2SLzTbMK6+9mUcfezJNl2xSYbv2a7RL7dq1c8El12b8+AkprVcvXTZZP53XWztNFm+cAw49LocetHdKSkoy+Pb7qvy7cO32v8kddz+YP5w4IOuus0YWXXSRbL/15lXWNeC0P2SHngdlk612z7577ZzvvpuSK6+9JY0bNcxJx1bPL6633nF/ateunW222rTK57ffevOccubFuf2uB3PUYX1y4TknZbNt98j6XXbK/vv0TOvlWubDjz7JQ3//Z54fdk+S5PST+uUfjz+Vrt17Z//ePdNu5bYZPeaL/PXeR/L4Q4OyeONG2XKzDbNcyxb53ZEn5q0j9k/tWrVy06C70rRpk0qB8Mx03WzD1KtXNzv1OiQH7LtbJk2anOtuviNLNV0yn43+7yzpRo0Wy/lnHZ/f9T0pG2yxS3bfZfssvnjjvPLam/n2229z3ZXnlm9bt27d7LrTtrnq2kGpXbt2pWUGTj7jwvzl1nvy1oh/zPZmZANO/0N23uPQbLvz/um547Z5/Y13ctWfB6XP3rtUuEnYvX97NAcefkKuvfzs9P7/G+2dc8HAPP3c8Gy1+cZp1XKZfP3NuNx9/9/zwvBXc+hBe2XFtstXOt+td9yf0tJ62XGHrebo9QPm3Jz00PrnGSHupEmTctlll6WkpCQTJ07Mfvvtl7vuuislJSXp1atXWrRokdGjR+fGG2/M5Zdfnk6dOuW8885Lw4b/vZnoj4455phcdNFFueGGG7LPPtbqBn4ZM22BQhs0aFCaN2+ezTbbLMmMJmy33XbLkCFDKswE+utf/5r27dtXaDh/9ONXe//617+madOmVf6L98+9O3iSHHLIIZXGftpwTpo0KWPHjs0GG2yQsrKy8mUAvvjiiwwbNiz77bdfhYbzf+vp3bt3pkyZkjvvvLN87LbbbssPP/wwx2HqQQcdVB7Y/lhznTp18uCDDyaZMdNszz33zH333ZcJE/47c2rQoEHZYIMN0qZN5RsLzYnVVlutwkzgH+9cvPnmm1e45h/HR4787x3rf/oafvfddxk7dmzWX3/9JMnw4cOTzJj98Y9//CM9evSosL7YiiuumG222aZCLXfddVemT5+enj17ZuzYseWPpZdeOiuttFIef/zxn3WNLHz+cNRBue2mS1OrVq2cdf4VOe6U8/K3hx5L1802zPbbVB1mVrf9e/fMMX0PzJNPv5BjTz43H3z4cR686/osukiDCtst3XypXH7Bqfl87Fc5+MiTsveBv88bb72bJZdokrtvvSpLN18qp559SS6+4vps0WWDnH3qMZXO9bv9e2X3XbbPzYPvTu8Dj0m/Y8+caV1bdNkg999xbZZcYvGcfs5luejy67Peuu3z+EOD02YmN+aaG99//33uuvfhdF5v7SzRZPEqt/nNaiun9fItc+sdM2Zcrbl6u/zrkduy0QYdc831Q9Lv+LNy9/1/rxA8L9uief716G3ZaYetMuTO+9Pv+LNyy233ZpONOmaR/1+6oG7durn9L5elbevlctrZl+TKa25Jn713ySEH7DnH9a+yUtvceuMlKSkpyXGnnJdrbhiS/ffpmcMP3rvStn323iV/HXxlGjVcLGf/6aqceOqfMuLl/6TbFpVnLO+1W48kyWabrF++Nu3PsV23zXLbzZfl66/H5ejjzsy9f3s0xx59cC49/5TZ7rvNVpumebOmuWnwX9P3j2fknAsGpl7durn28rNz4YDK/8A4fvzEPPToP7PNlpumcaPKwQfwy8xJD61/njG7d7/99stii824+eWJJ56YoUOH5oILLshtt92WcePGVVjiYM0118wyyyyTZ555psK5y8rKcvjhh+eSSy7JLbfcIrAFqoWZtkBhTZs2LUOGDMlmm22W99//7xqJnTp1ygUXXJChQ4dmq61mzM557733yr9aNTPvvfdeVlllldSpU31/9dWpUyctW1YOIkaNGpVTTjkl9913X6V1r8aNG5fkvwHl6quvPstztGvXLh07dsygQYOy//4zZvANGjQo66+//hzfAXillSqub7jYYotlmWWWyQcffFA+1rt375x77rm5++6707t377z11lt58cUXM3DgwDk6R1X+t5n+caZDq1atqhz/6Wv11Vdf5bTTTsuQIUPy+eefV9j+x9fw888/z7ffflvl6/C/Y++8807KysoqvRY/+mmozcKp9x47lc8anJ0e3bdKj+6znh346P1/qXK89XItM+WrNyuNb7pRpyrHq6qrVq1aOav/73NW/4qz0d9++bFK++/Xe9fs13vXSuOdO62TYX+/rdL4/9aw6KKL5KZr/jTH17H5pp2z+aazXrbl5OOOyMnHVQ4AZvce1K1bN5+++8xMn//RWy/9o8LPq626Um6/edbrCi7XskWFGaxVWbv9b/KvRyu/Zv9bc1Xvw4+233rzKmcqV3XjtJlt+7/q1Zvx99cePSvfgOzPV5yTP19xzmyP8aPfbtc1v92u6yy3qep96rrZhum62YZzfJ5GjRbLuE9fnuPtgTk3pz20/nnGjORjjpnxD5ZlZWX585//nKuuuqp8SYYddtgh7dq1q3Ds5s2b54svKq4hf/PNN2fixIm56qqr0qtXr9m+BgBzQmgLFNZjjz2Wzz77LEOGDMmQIUMqPT9o0KDy0La6zPzu4NOqHC8tLa10B/hp06Zlyy23zFdffZVjjz027dq1y6KLLppPPvkk++6778+6K3vv3r3Tt2/ffPzxx5kyZUqeeeaZXH755XN9nFlZbbXV0qFDh9xyyy3p3bt3brnlltSrVy89e/b82cec2Z3h5+SO8T179sxTTz2VP/zhD1lrrbWy2GKLZfr06dl6661/1ms4ffr0lJSU5KGHHqry/D/OsABY0Fx38+1ZbLFF0mP7LWu6FKAA5ncPvSD3z19++WX5t7W++OKLTJ48ucISbHXq1Mk666xTYZ+PPvooSy65ZIWxDTfcMCNGjMjll1+enj17ZokllpjregH+l9AWKKxBgwalWbNmueKKKyo9d9ddd+Xuu+/OwIED06BBg6ywwgqzvRnCCiuskGeffTbff//9TGdVNmkyYz3I/70p1YcffjjHdb/66qt5++23c9NNN1W4ccKjjz5aYbu2bdsmyRzdxGH33XdPv379cuutt+bbb79N3bp1s9tuu81xTe+880751+OSGTee+Oyzz7LttttW2K53797p169fPvvsswwePDjbbbdd+WsyP3399dcZOnRoTjvttJxyyn+/lvvOOxXvpN6sWbPUr18/7777bqVj/O/YCiuskLKysrRp0yYrr7zyvCmcBcbPXYcPiuRvDz+WN996L9fddEcOOWCPLLroIjVd0jzjzyzMuTntofXPSaNGjcpn8S655JKpW7du3nvvvQr3Rhg5cmT5zN6HHnooX3/9daUbAa+44oo577zz0qVLl2y99dYZOnRolWveAswNa9oChfTtt9/mrrvuyvbbb59ddtml0uPwww/PhAkTct999yVJdt5557z88su5++67Kx3rx1/0dt5554wdO7bKGao/brP88sundu3aGTZsWIXnr7zyyjmu/cdZnD/9BbOsrCyXXHJJhe2WWmqpbLLJJrn++uszatSoKuv5UdOmTbPNNtvklltuyaBBg7L11lunadOmc1zTNddcU+HO4VdddVV++OGHSuu+9urVKyUlJenbt29Gjhz5s29A9ktV9RomycUXX1xpu65du+aee+7Jp5/+90ZA7777bh566KEK2+60006pXbt2TjvttErHLSsry5dfflmNV0BR1a1bNylJpv7kzwMsqPode2bOOPfybL3lJjmliuUmfk2mfv99UjLjxprAzM1ND61/TlZdddU8++yz5TV07949v//97zNs2LC8//776d+/f4YPH54JEybkhhtuSK9evXLyySenUaNGla5hzTXXzIMPPpg33ngj3bt3z7fffjvH1w9QFTNtgUL68YZYO+xQeX2+JFl//fWz1FJLZdCgQdltt93yhz/8IXfeeWd23XXX7LfffunQoUO++uqr3HfffRk4cGDat2+f3r175+abb06/fv3y3HPPZeONN86kSZPyj3/8I4ceemh++9vfpnHjxtl1113L7yC7wgor5G9/+1ulNVVnpV27dllhhRVyzDHH5JNPPkmjRo3y17/+tdLaXEly6aWXZqONNso666yTgw46KG3atMkHH3yQBx54ICNGjKiwbe/evbPLLrskSc4444w5fzGTTJ06NVtssUV69uyZt956K1deeWU22mijSq/vUkstla233jp33HFHFl988Wy33XYzOeK81ahRo2yyySY577zz8v3332fZZZfN3//+9wrrsv3o1FNPzd///vdsuOGGOeSQQzJt2rRcfvnlWX311Su8hiussELOPPPMHH/88fnggw/So0ePNGzYMO+//37uvvvuHHTQQeVrmvHr1axZs5SU1MlHH3+WxRtX/oULFiSzWjv312bUR5+mpKRullpqqZouBQptbnrowYMHL/T98/bbb5/rrrsuhx12WEpKSnLRRRdlq622yqabbppkRhB70EEH5eqrr86wYcNy+umn58gjj5zpday//vq59957s+2222aXXXbJPffc474JwM8mtAUKadCgQalfv3623LLq9flq1aqV7bbbLoMGDcqXX36ZJZdcMv/617/Sv3//3H333bnpppvSrFmzbLHFFuU3Oqhdu3YefPDBnHXWWRk8eHD++te/Zskll8xGG22UNdZYo/zYl112Wb7//vsMHDgwpaWl6dmzZ84///zZ3vDgR3Xr1s3999+fI488MgMGDEj9+vWz44475vDDD0/79u0rbNu+ffs888wzOfnkk3PVVVflu+++y/LLL1/lOrLdu3dPkyZNMn369Jk24jNz+eWXZ9CgQTnllFPy/fffp1evXrn00kurXIOsd+/e+dvf/paePXumtLR0rs5TnQYPHpwjjjgiV1xxRcrKyrLVVlvloYceKl937EcdOnTIQw89lGOOOSYnn3xyWrVqldNPPz1vvPFG3nyz4k2SjjvuuKy88sq56KKLctpppyWZcVO0rbbaaq5fUxZMjRo1SqvlV8jwl17Pau1WnOn6ykBxTJs2LS+N+E9aLd+2ytltwH/NTQ89ZcqUhb5/PvjggzNgwIBceuml6du3b5Zbbrm89tprGT58eOrUqZO11lorH330UQ477LCsttpqc9Q3bL755rn99tuz8847Z++9987gwYMrreELMCdKyiwQBbBA+OGHH9KiRYt0794911133Tw7z7333psePXpk2LBh2XjjjefZeea1Hj165PXXX6+0Di6MHDkyg/5yY1o2Xywbdu6Q5Vq28MsUFND06dMz6uNP8++nX8zHYyZmz733LV/PEmBOzEn/fPvtt2fPPffMZZddlt/97ndVbjNq1Kh8/PHH2WCDDeZluQAVmGkLsIC455578sUXX1S4OcO8cO2116Zt27bZaKON5ul5qtO3336bBg0alP/8zjvv5MEHH8w+++xTg1VRVG3btk2vPXvnvnvvzi23P5K6tZPSenWTmdz9GqgBZWWZMvX7fD8tWXyJpdJrz94CW2CuzUn/3LNnz4wbNy6HHnpoBg0alAMOOCAdO3bMYostlg8//DD33ntvrr766myzzTZCW2C+MtMWoOCeffbZvPLKKznjjDPStGnTDB8+fJ6cZ8iQIXnllVcyYMCAXHLJJbNcr6tolllmmey774wZWB9++GGuuuqqTJkyJS+99FJWWmmlmi6PgiorK8snn3ySjz/+OFOmTHF3eiiQkpKSlJaWpmXLlll22WWrXM4HYGZ+Tv/86quv5uSTT87DDz+cKVOmlI+vvPLK+f3vf58DDjjAN3OA+UpoC1Bw++67b2655ZastdZaufHGG+d4bbC5VVJSksUWWyy77bZbBg4cmDp1FpwvY/Tp0yePP/54Ro8endLS0nTu3Dlnn3121llnnZouDQCA+eyX9M+TJk3K22+/nYkTJ6Zly5Zp06bNPKwUYOaEtgAAAAAABWJuPwAAAABAgSw4332tJtOnT8+nn36ahg0bWhsLAGABU1ZWlgkTJqRFixYL9dqCeloAgAXTnPazC11o++mnn6ZVq1Y1XQYAAL/ARx99lJYtW9Z0GTVGTwsAsGCbXT+70IW2DRs2TDLjhWnUqFENVwMAwNwYP358WrVqVd7TLaz0tAAAC6Y57WcXutD2x6+PNWrUSIMLALCAWtiXBNDTAgAs2GbXzy68C4EBAAAAABSQ0BYAAAAAoECEtgAAAAAABSK0BQAAAAAoEKEtAAAAAECBCG0BAAAAAApEaAsAAAAAUCBCWwAAAACAAhHaAgAAAAAUiNAWAAAAAKBAhLYAAAAAAAUitAUAAAAAKBChLQAAAABAgQhtAQAAAAAKRGgLAAAAAFAgQlsAAAAAgAIR2gIAAAAAFIjQFgAAAACgQIS2AAAAAAAFIrQFAAAAACgQoS0AAAAAQIEIbQEAAAAACkRoCwAAAABQIEJbAAAAAIACEdoCAAAAABSI0BYAAAAAoECEtgAAAAAABSK0BQAAAAAoEKEtAAAAAECBCG0BAAAAAApEaAsAAAAAUCBCWwAAAACAAhHaAgAAAAAUiNAWAAAAAKBAhLYAAAAAAAUitAUAAAAAKJAaDW2HDRuW7t27p0WLFikpKck999wz232eeOKJrLPOOiktLc2KK66YG2+8cZ7XCQAAM6OnBQCgutVoaDtp0qS0b98+V1xxxRxt//7772e77bbLZpttlhEjRuSoo47KAQcckEceeWQeVwoAAFXT0wIAUN3q1OTJt9lmm2yzzTZzvP3AgQPTpk2bXHDBBUmSVVddNU8++WQuuuiidOvWrcp9pkyZkilTppT/PH78+F9WNAAA/ISeFgCA6rZArWn79NNPp2vXrhXGunXrlqeffnqm+wwYMCCNGzcuf7Rq1WpelwkAADOlpwUAYHYWqNB29OjRad68eYWx5s2bZ/z48fn222+r3Of444/PuHHjyh8fffTR/CgVAACqpKcFAGB2anR5hPmhtLQ0paWlNV0GAAD8bHpaAICFywI103bppZfOmDFjKoyNGTMmjRo1SoMGDWqoKgAAmHN6WgAAZmeBCm07d+6coUOHVhh79NFH07lz5xqqCAAA5o6eFgCA2anR0HbixIkZMWJERowYkSR5//33M2LEiIwaNSrJjLW7evfuXb797373u4wcOTJ//OMf8+abb+bKK6/M7bffnqOPPromygcAAD0tAADVrkZD2xdeeCFrr7121l577SRJv379svbaa+eUU05Jknz22WflzW6StGnTJg888EAeffTRtG/fPhdccEH+/Oc/p1u3bjVSPwAA6GkBAKhuJWVlZWU1XcT8NH78+DRu3Djjxo1Lo0aNarocAADmgl5uBq8DAMCCaU77uAVqTVsAAAAAgF87oS0AAAAAQIEIbQEAAAAACkRoCwAAAABQIEJbAAAAAIACEdoCAAAAABSI0BYAAAAAoECEtgAAAAAABSK0BQAAAAAoEKEtAAAAAECBCG0BAAAAAApEaAsAAAAAUCBCWwAAAACAAhHaAgAAAAAUiNAWAAAAAKBAhLYAAAAAAAUitAUAAAAAKBChLQAAAABAgQhtAQAAAAAKRGgLAAAAAFAgQlsAAAAAgAIR2gIAAAAAFIjQFgAAAACgQIS2AAAAAAAFIrQFAAAAACgQoS0AAAAAQIEIbQEAAAAACkRoCwAAAABQIEJbAAAAAIACEdoCAAAAABSI0BYAAAAAoECEtgAAAAAABSK0BQAAAAAoEKEtAAAAAECBCG0BAAAAAApEaAsAAAAAUCBCWwAAAACAAhHaAgAAAAAUiNAWAAAAAKBAhLYAAAAAAAUitAUAAAAAKBChLQAAAABAgQhtAQAAAAAKpE5NFwAAAAAAvzaXfH1JTZfAXOrbpG9Nl1DOTFsAAAAAgAIR2gIAAAAAFIjQFgAAAACgQIS2AAAAAAAFIrQFAAAAACgQoS0AAAAAQIEIbQEAAAAACkRoCwAAAABQIEJbAAAAAIACEdoCAAAAABSI0BYAAAAAoECEtgAAAAAABSK0BQAAAAAoEKEtAAAAAECBCG0BAAAAAApEaAsAAAAAUCBCWwAAAACAAhHaAgAAAAAUiNAWAAAAAKBAhLYAAAAAAAUitAUAAAAAKBChLQAAAABAgQhtAQAAAAAKRGgLAAAAAFAgQlsAAAAAgAIR2gIAAAAAFIjQFgAAAACgQIS2AAAAAAAFIrQFAAAAACiQGg9tr7jiirRu3Tr169dPp06d8txzz81y+4svvjirrLJKGjRokFatWuXoo4/Od999N5+qBQCAyvS0AABUpxoNbW+77bb069cv/fv3z/Dhw9O+fft069Ytn3/+eZXbDx48OMcdd1z69++fN954I9ddd11uu+22nHDCCfO5cgAAmEFPCwBAdavR0PbCCy/MgQcemD59+mS11VbLwIEDs8gii+T666+vcvunnnoqG264YfbYY4+0bt06W221VXr16jXLmQxTpkzJ+PHjKzwAAKC66GkBAKhuNRbaTp06NS+++GK6du3632Jq1UrXrl3z9NNPV7nPBhtskBdffLG8oR05cmQefPDBbLvttjM9z4ABA9K4cePyR6tWrar3QgAAWGjpaQEAmBfq1NSJx44dm2nTpqV58+YVxps3b54333yzyn322GOPjB07NhtttFHKysryww8/5He/+90sv0p2/PHHp1+/fuU/jx8/XpMLAEC10NMCADAv1PiNyObGE088kbPPPjtXXnllhg8fnrvuuisPPPBAzjjjjJnuU1pamkaNGlV4AABATdHTAgAwOzU207Zp06apXbt2xowZU2F8zJgxWXrppavc5+STT87ee++dAw44IEmyxhprZNKkSTnooINy4oknplatBSqDBgBgAaenBQBgXqixjrBevXrp0KFDhg4dWj42ffr0DB06NJ07d65yn8mTJ1dqYmvXrp0kKSsrm3fFAgBAFfS0AADMCzU20zZJ+vXrl3322Sfrrrtu1ltvvVx88cWZNGlS+vTpkyTp3bt3ll122QwYMCBJ0r1791x44YVZe+2106lTp7z77rs5+eST07179/JGFwAA5ic9LQAA1a1GQ9vddtstX3zxRU455ZSMHj06a621Vh5++OHyGzmMGjWqwiyEk046KSUlJTnppJPyySefZKmllkr37t1z1lln1dQlAACwkNPTAgBQ3UrKFrLvYI0fPz6NGzfOuHHj3MABAGABo5ebwesAAMV3ydeX1HQJzKW+TfrO83PMaR/nLgcAAAAAAAUitAUAAAAAKBChLQAAAABAgQhtAQAAAAAKRGgLAAAAAFAgQlsAAAAAgAIR2gIAAAAAFIjQFgAAAACgQOrUdAEAAAAAC6JLvr6kpktgLvVt0remS4A5YqYtAAAAAECBCG0BAAAAAApEaAsAAAAAUCBCWwAAAACAAhHaAgAAAAAUiNAWAAAAAKBAhLYAAAAAAAUitAUAAAAAKBChLQAAAABAgQhtAQAAAAAKRGgLAAAAAFAgQlsAAAAAgAIR2gIAAAAAFIjQFgAAAACgQIS2AAAAAAAFIrQFAAAAACgQoS0AAAAAQIEIbQEAAAAACkRoCwAAAABQIEJbAAAAAIACEdoCAAAAABSI0BYAAAAAoECEtgAAAAAABSK0BQAAAAAoEKEtAAAAAECBCG0BAAAAAApEaAsAAAAAUCBCWwAAAACAAhHaAgAAAAAUiNAWAAAAAKBAhLYAAAAAAAUitAUAAAAAKBChLQAAAABAgQhtAQAAAAAKRGgLAAAAAFAgQlsAAAAAgAIR2gIAAAAAFIjQFgAAAACgQIS2AAAAAAAFIrQFAAAAACgQoS0AAAAAQIEIbQEAAAAACkRoCwAAAABQIEJbAAAAAIACEdoCAAAAABSI0BYAAAAAoECEtgAAAAAABSK0BQAAAAAoEKEtAAAAAECBCG0BAAAAAApEaAsAAAAAUCBCWwAAAACAAhHaAgAAAAAUiNAWAAAAAKBAhLYAAAAAAAUitAUAAAAAKBChLQAAAABAgQhtAQAAAAAKRGgLAAAAAFAgQlsAAAAAgAKZ69C2devWOf300zNq1Kh5UQ8AAMxzeloAAIpsrkPbo446KnfddVfatm2bLbfcMkOGDMmUKVPmRW0AADBP6GkBACiynxXajhgxIs8991xWXXXVHHHEEVlmmWVy+OGHZ/jw4fOiRgAAqFZ6WgAAiuxnr2m7zjrr5NJLL82nn36a/v37589//nM6duyYtdZaK9dff33Kysqqs04AAKh2eloAAIroZ4e233//fW6//fbssMMO+f3vf5911103f/7zn7PzzjvnhBNOyJ577jlHx7niiivSunXr1K9fP506dcpzzz03y+2/+eabHHbYYVlmmWVSWlqalVdeOQ8++ODPvQwAABZieloAAIqoztzuMHz48Nxwww259dZbU6tWrfTu3TsXXXRR2rVrV77NjjvumI4dO872WLfddlv69euXgQMHplOnTrn44ovTrVu3vPXWW2nWrFml7adOnZott9wyzZo1y5133plll102H374YRZffPG5vQwAABZielpgTlzy9SU1XQJzqW+TvjVdAkC1mOvQtmPHjtlyyy1z1VVXpUePHqlbt26lbdq0aZPdd999tse68MILc+CBB6ZPnz5JkoEDB+aBBx7I9ddfn+OOO67S9tdff32++uqrPPXUU+Xnbd269dxeAgAACzk9LQAARTbXyyOMHDkyDz/8cHbdddcqm9skWXTRRXPDDTfM8jhTp07Niy++mK5du/63mFq10rVr1zz99NNV7nPfffelc+fOOeyww9K8efOsvvrqOfvsszNt2rSZnmfKlCkZP358hQcAAAs3PS0AAEU216Ht559/nmeffbbS+LPPPpsXXnhhjo8zduzYTJs2Lc2bN68w3rx584wePbrKfUaOHJk777wz06ZNy4MPPpiTTz45F1xwQc4888yZnmfAgAFp3Lhx+aNVq1ZzXCMAAL9OeloAAIpsrkPbww47LB999FGl8U8++SSHHXZYtRQ1M9OnT0+zZs1yzTXXpEOHDtltt91y4oknZuDAgTPd5/jjj8+4cePKH1XVDgDAwkVPCwBAkc31mrb/+c9/ss4661QaX3vttfOf//xnjo/TtGnT1K5dO2PGjKkwPmbMmCy99NJV7rPMMsukbt26qV27dvnYqquumtGjR2fq1KmpV69epX1KS0tTWlo6x3UBAPDrp6cFAKDI5nqmbWlpaaWmNEk+++yz1Kkz5xlwvXr10qFDhwwdOrR8bPr06Rk6dGg6d+5c5T4bbrhh3n333UyfPr187O23384yyyxTZXMLAABV0dMCAFBkcx3abrXVVuVfz/rRN998kxNOOCFbbrnlXB2rX79+ufbaa3PTTTfljTfeyCGHHJJJkyaV33m3d+/eOf7448u3P+SQQ/LVV1+lb9++efvtt/PAAw/k7LPPnudfYQMA4NdFTwsAQJHN9fIIf/rTn7LJJptk+eWXz9prr50kGTFiRJo3b56//OUvc3Ws3XbbLV988UVOOeWUjB49OmuttVYefvjh8hs5jBo1KrVq/TdXbtWqVR555JEcffTRWXPNNbPsssumb9++OfbYY+f2MgAAWIjpaQEAKLKSsrKysrndadKkSRk0aFBefvnlNGjQIGuuuWZ69eqVunXrzosaq9X48ePTuHHjjBs3Lo0aNarpcgAAmAvV2cvpaYHZueTrS2q6BOZS3yZ95+v5fEYWPPPzM+LzseCZH5+POe3j5nqmbZIsuuiiOeigg352cQAAUNP0tAAAFNXPCm2TGXfcHTVqVKZOnVphfIcddvjFRQEAwPygpwUAoIjmOrQdOXJkdtxxx7z66qspKSnJj6srlJSUJEmmTZtWvRUCAEA109MCAFBktWa/SUV9+/ZNmzZt8vnnn2eRRRbJ66+/nmHDhmXdddfNE088MQ9KBACA6qWnBQCgyOZ6pu3TTz+dxx57LE2bNk2tWrVSq1atbLTRRhkwYECOPPLIvPTSS/OiTgAAqDZ6WgAAimyuZ9pOmzYtDRs2TJI0bdo0n376aZJk+eWXz1tvvVW91QEAwDygpwUAoMjmeqbt6quvnpdffjlt2rRJp06dct5556VevXq55ppr0rZt23lRIwAAVCs9LQAARTbXoe1JJ52USZMmJUlOP/30bL/99tl4442z5JJL5rbbbqv2AgEAoLrpaQEAKLK5Dm27detW/t8rrrhi3nzzzXz11Vdp0qRJ+d12AQCgyPS0AAAU2Vytafv999+nTp06ee211yqML7HEEppbAAAWCHpaAACKbq5C27p162a55ZbLtGnT5lU9AAAwT+lpAQAourkKbZPkxBNPzAknnJCvvvpqXtQDAADznJ4WAIAim+s1bS+//PK8++67adGiRZZffvksuuiiFZ4fPnx4tRUHAADzgp4WAIAim+vQtkePHvOgDAAAmH/0tAAAFNlch7b9+/efF3UAAMB8o6cFAKDI5npNWwAAAAAA5p25nmlbq1atlJSUzPR5d+EFAKDo9LQAABTZXIe2d999d4Wfv//++7z00ku56aabctppp1VbYQAAMK/oaQEAKLK5Dm1/+9vfVhrbZZdd8pvf/Ca33XZb9t9//2opDAAA5hU9LQAARVZta9quv/76GTp0aHUdDgAA5js9LQAARVAtoe23336bSy+9NMsuu2x1HA4AAOY7PS0AAEUx18sjNGnSpMJNG8rKyjJhwoQsssgiueWWW6q1OAAAmBf0tAAAFNlch7YXXXRRhQa3Vq1aWWqppdKpU6c0adKkWosDAIB5QU8LAECRzXVou++++86DMgAAYP7R0wIAUGRzvabtDTfckDvuuKPS+B133JGbbrqpWooCAIB5SU8LAECRzXVoO2DAgDRt2rTSeLNmzXL22WdXS1EAADAv6WkBACiyuQ5tR40alTZt2lQaX3755TNq1KhqKQoAAOYlPS0AAEU212vaNmvWLK+88kpat25dYfzll1/OkksuWV11AQDMsXNeGlvTJSy0jlu78mzVBYGeFgCAIpvrmba9evXKkUcemccffzzTpk3LtGnT8thjj6Vv377Zfffd50WNAABQrfS0AAAU2VzPtD3jjDPywQcfZIsttkidOjN2nz59enr37m39LwAAFgh6WgAAimyuQ9t69erltttuy5lnnpkRI0akQYMGWWONNbL88svPi/oAAKDa6WkBACiyuQ5tf7TSSitlpZVWqs5aAABgvtLTAgBQRHO9pu3OO++cc889t9L4eeedl1133bVaigIAgHlJTwsAQJHNdWg7bNiwbLvttpXGt9lmmwwbNqxaigIAgHlJTwsAQJHNdWg7ceLE1KtXr9J43bp1M378+GopCgAA5iU9LQAARTbXoe0aa6yR2267rdL4kCFDstpqq1VLUQAAMC/paQEAKLK5vhHZySefnJ122invvfdeNt988yTJ0KFDM3jw4Nx5553VXiAAAFQ3PS0AAEU216Ft9+7dc8899+Tss8/OnXfemQYNGqR9+/Z57LHHssQSS8yLGgEAoFrpaQEAKLK5Dm2TZLvttst2222XJBk/fnxuvfXWHHPMMXnxxRczbdq0ai0QAADmBT0tAABFNddr2v5o2LBh2WeffdKiRYtccMEF2XzzzfPMM89UZ20AADBP6WkBACiiuZppO3r06Nx444257rrrMn78+PTs2TNTpkzJPffc44YNAAAsEPS0AAAU3RzPtO3evXtWWWWVvPLKK7n44ovz6aef5rLLLpuXtQEAQLXS0wIAsCCY45m2Dz30UI488sgccsghWWmlleZlTQAAME/oaQEAWBDM8UzbJ598MhMmTEiHDh3SqVOnXH755Rk7duy8rA0AAKqVnhYAgAXBHIe266+/fq699tp89tlnOfjggzNkyJC0aNEi06dPz6OPPpoJEybMyzoBAOAX09MCALAgmOPQ9keLLrpo9ttvvzz55JN59dVX8/vf/z7nnHNOmjVrlh122GFe1AgAANVKTwsAQJHNdWj7U6usskrOO++8fPzxx7n11lurqyYAAJhv9LQAABTNLwptf1S7du306NEj9913X3UcDgAA5js9LQAARVEtoS0AAAAAANVDaAsAAAAAUCBCWwAAAACAAhHaAgAAAAAUiNAWAAAAAKBAhLYAAAAAAAUitAUAAAAAKBChLQAAAABAgQhtAQAAAAAKRGgLAAAAAFAgQlsAAAAAgAIR2gIAAAAAFIjQFgAAAACgQIS2AAAAAAAFIrQFAAAAACgQoS0AAAAAQIEIbQEAAAAACkRoCwAAAABQIEJbAAAAAIACEdoCAAAAABSI0BYAAAAAoECEtgAAAAAABSK0BQAAAAAoEKEtAAAAAECBCG0BAAAAAAqkEKHtFVdckdatW6d+/frp1KlTnnvuuTnab8iQISkpKUmPHj3mbYEAADAL+lkAAKpTjYe2t912W/r165f+/ftn+PDhad++fbp165bPP/98lvt98MEHOeaYY7LxxhvPp0oBAKAy/SwAANWtxkPbCy+8MAceeGD69OmT1VZbLQMHDswiiyyS66+/fqb7TJs2LXvuuWdOO+20tG3bdj5WCwAAFelnAQCobjUa2k6dOjUvvvhiunbtWj5Wq1atdO3aNU8//fRM9zv99NPTrFmz7L///rM9x5QpUzJ+/PgKDwAAqA7zo59N9LQAAAubGg1tx44dm2nTpqV58+YVxps3b57Ro0dXuc+TTz6Z6667Ltdee+0cnWPAgAFp3Lhx+aNVq1a/uG4AAEjmTz+b6GkBABY2dWq6gLkxYcKE7L333rn22mvTtGnTOdrn+OOPT79+/cp/Hj9+fI00uee8NHa+n5MZjlt7zj4rAADz2s/pZ5Pi9LS/Rpd8fUlNl8Bc6tukb02XAADzXI2Gtk2bNk3t2rUzZsyYCuNjxozJ0ksvXWn79957Lx988EG6d+9ePjZ9+vQkSZ06dfLWW29lhRVWqLBPaWlpSktL50H1AAAs7OZHP5voaQEAFjY1ujxCvXr10qFDhwwdOrR8bPr06Rk6dGg6d+5caft27drl1VdfzYgRI8ofO+ywQzbbbLOMGDHCbAMAAOYr/SwAAPNCjS+P0K9fv+yzzz5Zd911s9566+Xiiy/OpEmT0qdPnyRJ7969s+yyy2bAgAGpX79+Vl999Qr7L7744klSaRwAAOYH/SwAANWtxkPb3XbbLV988UVOOeWUjB49OmuttVYefvjh8ps5jBo1KrVq1eiEYAAAmCn9LAAA1a3GQ9skOfzww3P44YdX+dwTTzwxy31vvPHG6i8IAADmgn4WAIDq5J/8AQAAAAAKRGgLAAAAAFAgQlsAAAAAgAIR2gIAAAAAFIjQFgAAAACgQIS2AAAAAAAFIrQFAAAAACiQOjVdAADMT+e8NLamS1goHbd205ouAQAAYIFhpi0AAAAAQIEIbQEAAAAACkRoCwAAAABQIEJbAAAAAIACEdoCAAAAABSI0BYAAAAAoECEtgAAAAAABSK0BQAAAAAoEKEtAAAAAECBCG0BAAAAAApEaAsAAAAAUCBCWwAAAACAAhHaAgAAAAAUiNAWAAAAAKBAhLYAAAAAAAUitAUAAAAAKBChLQAAAABAgQhtAQAAAAAKRGgLAAAAAFAgQlsAAAAAgAIR2gIAAAAAFIjQFgAAAACgQIS2AAAAAAAFIrQFAAAAACiQOjVdAEARnfPS2JouYaF13NpNa7oEAAAAqFFm2gIAAAAAFIjQFgAAAACgQIS2AAAAAAAFIrQFAAAAACgQoS0AAAAAQIEIbQEAAAAACkRoCwAAAABQIHVqugBYkJ3z0tiaLmGhddzaTWu6BAAAAIB5wkxbAAAAAIACEdoCAAAAABSI0BYAAAAAoECEtgAAAAAABSK0BQAAAAAoEKEtAAAAAECBCG0BAAAAAApEaAsAAAAAUCBCWwAAAACAAhHaAgAAAAAUiNAWAAAAAKBAhLYAAAAAAAUitAUAAAAAKBChLQAAAABAgQhtAQAAAAAKRGgLAAAAAFAgQlsAAAAAgAIR2gIAAAAAFIjQFgAAAACgQIS2AAAAAAAFIrQFAAAAACgQoS0AAAAAQIEIbQEAAAAACkRoCwAAAABQIEJbAAAAAIACEdoCAAAAABSI0BYAAAAAoECEtgAAAAAABSK0BQAAAAAoEKEtAAAAAECBCG0BAAAAAApEaAsAAAAAUCCFCG2vuOKKtG7dOvXr10+nTp3y3HPPzXTba6+9NhtvvHGaNGmSJk2apGvXrrPcHgAA5jX9LAAA1anGQ9vbbrst/fr1S//+/TN8+PC0b98+3bp1y+eff17l9k888UR69eqVxx9/PE8//XRatWqVrbbaKp988sl8rhwAAPSzAABUvxoPbS+88MIceOCB6dOnT1ZbbbUMHDgwiyyySK6//voqtx80aFAOPfTQrLXWWmnXrl3+/Oc/Z/r06Rk6dOh8rhwAAPSzAABUvxoNbadOnZoXX3wxXbt2LR+rVatWunbtmqeffnqOjjF58uR8//33WWKJJap8fsqUKRk/fnyFBwAAVIf50c8meloAgIVNjYa2Y8eOzbRp09K8efMK482bN8/o0aPn6BjHHntsWrRoUaFR/qkBAwakcePG5Y9WrVr94roBACCZP/1soqcFAFjY1PjyCL/EOeeckyFDhuTuu+9O/fr1q9zm+OOPz7hx48ofH3300XyuEgAAqjYn/WyipwUAWNjUqcmTN23aNLVr186YMWMqjI8ZMyZLL730LPf905/+lHPOOSf/+Mc/suaaa850u9LS0pSWllZLvQAA8FPzo59N9LQAAAubGp1pW69evXTo0KHCTRd+vAlD586dZ7rfeeedlzPOOCMPP/xw1l133flRKgAAVKKfBQBgXqjRmbZJ0q9fv+yzzz5Zd911s9566+Xiiy/OpEmT0qdPnyRJ7969s+yyy2bAgAFJknPPPTennHJKBg8enNatW5evFbbYYotlscUWq7HrAABg4aSfBQCgutV4aLvbbrvliy++yCmnnJLRo0dnrbXWysMPP1x+M4dRo0alVq3/Tgi+6qqrMnXq1Oyyyy4VjtO/f/+ceuqp87N0AADQzwIAUO1qPLRNksMPPzyHH354lc898cQTFX7+4IMP5n1BAAAwF/SzAABUpxpd0xYAAAAAgIqEtgAAAAAABSK0BQAAAAAoEKEtAAAAAECBCG0BAAAAAApEaAsAAAAAUCBCWwAAAACAAhHaAgAAAAAUiNAWAAAAAKBAhLYAAAAAAAUitAUAAAAAKBChLQAAAABAgQhtAQAAAAAKRGgLAAAAAFAgQlsAAAAAgAIR2gIAAAAAFIjQFgAAAACgQIS2AAAAAAAFIrQFAAAAACgQoS0AAAAAQIEIbQEAAAAACkRoCwAAAABQIEJbAAAAAIACEdoCAAAAABSI0BYAAAAAoECEtgAAAAAABSK0BQAAAAAoEKEtAAAAAECBCG0BAAAAAApEaAsAAAAAUCBCWwAAAACAAhHaAgAAAAAUiNAWAAAAAKBAhLYAAAAAAAUitAUAAAAAKBChLQAAAABAgQhtAQAAAAAKRGgLAAAAAFAgQlsAAAAAgAIR2gIAAAAAFIjQFgAAAACgQIS2AAAAAAAFIrQFAAAAACgQoS0AAAAAQIEIbQEAAAAACkRoCwAAAABQIEJbAAAAAIACEdoCAAAAABSI0BYAAAAAoECEtgAAAAAABSK0BQAAAAAoEKEtAAAAAECBCG0BAAAAAApEaAsAAAAAUCBCWwAAAACAAhHaAgAAAAAUiNAWAAAAAKBAhLYAAAAAAAUitAUAAAAAKBChLQAAAABAgQhtAQAAAAAKRGgLAAAAAFAgQlsAAAAAgAIR2gIAAAAAFIjQFgAAAACgQIS2AAAAAAAFIrQFAAAAACgQoS0AAAAAQIEIbQEAAAAACkRoCwAAAABQIEJbAAAAAIACEdoCAAAAABSI0BYAAAAAoEAKEdpeccUVad26derXr59OnTrlueeem+X2d9xxR9q1a5f69etnjTXWyIMPPjifKgUAgMr0swAAVKcaD21vu+229OvXL/3798/w4cPTvn37dOvWLZ9//nmV2z/11FPp1atX9t9//7z00kvp0aNHevTokddee20+Vw4AAPpZAACqX42HthdeeGEOPPDA9OnTJ6uttloGDhyYRRZZJNdff32V219yySXZeuut84c//CGrrrpqzjjjjKyzzjq5/PLL53PlAACgnwUAoPrVqcmTT506NS+++GKOP/748rFatWqla9euefrpp6vc5+mnn06/fv0qjHXr1i333HNPldtPmTIlU6ZMKf953LhxSZLx48f/wurnzncTJ8zX8/Ff48fXm2fH9r7WnHn5vibe25rkvf118r7+es3r97by+Wb0cGVlZfP1vDMzP/rZpDg97a/Rd+O/q+kSmEvja8+/z73Px4Jnfn4+Ep+RBZG/Q5iV+fH5mNN+tkZD27Fjx2batGlp3rx5hfHmzZvnzTffrHKf0aNHV7n96NGjq9x+wIABOe200yqNt2rV6mdWzYKm8rvPr4H39dfLe/vr5H399aqp93bChAlp3LhxDZ39v+ZHP5voaeGnjstxNV0CBebzwez4jDAr8/PzMbt+tkZD2/nh+OOPrzCTYfr06fnqq6+y5JJLpqSkpAYrWzCMHz8+rVq1ykcffZRGjRrVdDlUI+/tr5P39dfLe/vr5b2dO2VlZZkwYUJatGhR06XMV//b037zzTdZfvnlM2rUqEKE1xSLv1eYHZ8RZsXng9nxGfll5rSfrdHQtmnTpqldu3bGjBlTYXzMmDFZeumlq9xn6aWXnqvtS0tLU1paWmFs8cUX//lFL6QaNWrkD+KvlPf218n7+uvlvf318t7OuSKFlPOjn02q7mmTGa+Fzw0z4+8VZsdnhFnx+WB2fEZ+vjnpZ2v0RmT16tVLhw4dMnTo0PKx6dOnZ+jQoencuXOV+3Tu3LnC9kny6KOPznR7AACYV/SzAADMCzW+PEK/fv2yzz77ZN111816662Xiy++OJMmTUqfPn2SJL17986yyy6bAQMGJEn69u2bTTfdNBdccEG22267DBkyJC+88EKuueaamrwMAAAWUvpZAACqW42Htrvttlu++OKLnHLKKRk9enTWWmutPPzww+U3Zxg1alRq1frvhOANNtgggwcPzkknnZQTTjghK620Uu65556svvrqNXUJv2qlpaXp379/lV/HY8Hmvf118r7+enlvf728twu+muhnfW6YFZ8PZsdnhFnx+WB2fEbmj5KysrKymi4CAAAAAIAZanRNWwAAAAAAKhLaAgAAAAAUiNAWAAAAAKBAhLYAAAAAAAUitKVKw4YNS/fu3dOiRYuUlJTknnvuqemSqAYDBgxIx44d07BhwzRr1iw9evTIW2+9VdNlUQ2uuuqqrLnmmmnUqFEaNWqUzp0756GHHqrpsqhm55xzTkpKSnLUUUfVdCn8QqeeempKSkoqPNq1a1fTZbEA0KMxK3o9ZkW/yNzQd/K/9K/zn9CWKk2aNCnt27fPFVdcUdOlUI3++c9/5rDDDsszzzyTRx99NN9//3222mqrTJo0qaZL4xdq2bJlzjnnnLz44ot54YUXsvnmm+e3v/1tXn/99ZoujWry/PPP5+qrr86aa65Z06VQTX7zm9/ks88+K388+eSTNV0SCwA9GrOi12NW9IvMKX0nM6N/nb/q1HQBFNM222yTbbbZpqbLoJo9/PDDFX6+8cYb06xZs7z44ovZZJNNaqgqqkP37t0r/HzWWWflqquuyjPPPJPf/OY3NVQV1WXixInZc889c+211+bMM8+s6XKoJnXq1MnSSy9d02WwgNGjMSt6PWZFv8ic0HcyK/rX+ctMW1iIjRs3LkmyxBJL1HAlVKdp06ZlyJAhmTRpUjp37lzT5VANDjvssGy33Xbp2rVrTZdCNXrnnXfSokWLtG3bNnvuuWdGjRpV0yUBvzJ6PWZGv8jM6DuZFf3r/GWmLSykpk+fnqOOOiobbrhhVl999Zouh2rw6quvpnPnzvnuu++y2GKL5e67785qq61W02XxCw0ZMiTDhw/P888/X9OlUI06deqUG2+8Maussko+++yznHbaadl4443z2muvpWHDhjVdHvAroNejKvpFZkXfyazoX+c/oS0spA477LC89tpr1qD5FVlllVUyYsSIjBs3LnfeeWf22Wef/POf/9SIL8A++uij9O3bN48++mjq169f0+VQjX769fY111wznTp1yvLLL5/bb789+++/fw1WBvxa6PWoin6RmdF3Mjv61/lPaAsLocMPPzx/+9vfMmzYsLRs2bKmy6Ga1KtXLyuuuGKSpEOHDnn++edzySWX5Oqrr67hyvi5XnzxxXz++edZZ511ysemTZuWYcOG5fLLL8+UKVNSu3btGqyQ6rL44otn5ZVXzrvvvlvTpQC/Ano9Zka/yMzoO5lb+td5T2gLC5GysrIcccQRufvuu/PEE0+kTZs2NV0S89D06dMzZcqUmi6DX2CLLbbIq6++WmGsT58+adeuXY499liN86/IxIkT895772Xvvfeu6VKABZhej7mlX+RH+k7mlv513hPaUqWJEydW+NeS999/PyNGjMgSSyyR5ZZbrgYr45c47LDDMnjw4Nx7771p2LBhRo8enSRp3LhxGjRoUMPV8Uscf/zx2WabbbLccstlwoQJGTx4cJ544ok88sgjNV0av0DDhg0rrUO46KKLZskll7Q+4QLumGOOSffu3bP88svn008/Tf/+/VO7du306tWrpkuj4PRozIpej1nRLzIr+k5mR/86/wltqdILL7yQzTbbrPznfv36JUn22Wef3HjjjTVUFb/UVVddlSTp0qVLhfEbbrgh++677/wviGrz+eefp3fv3vnss8/SuHHjrLnmmnnkkUey5ZZb1nRpQBU+/vjj9OrVK19++WWWWmqpbLTRRnnmmWey1FJL1XRpFJwejVnR6zEr+kXgl9C/zn8lZWVlZTVdBAAAAAAAM9Sq6QIAAAAAAPgvoS0AAAAAQIEIbQEAAAAACkRoCwAAAABQIEJbAAAAAIACEdoCAAAAABSI0BYAAAAAoECEtgAAAAAABSK0BQAAAGpcly5dctRRR9V0GQCFILQFqGH77rtvevToUdNlAAAAAAUhtAUAAAAAKBChLUCBdOnSJUcccUSOOuqoNGnSJM2bN8+1116bSZMmpU+fPmnYsGFWXHHFPPTQQ+X7TJs2Lfvvv3/atGmTBg0aZJVVVskll1xS4bg//PBDjjzyyCy++OJZcsklc+yxx2afffapMMN3+vTpGTBgQPlx2rdvnzvvvHN+XToAAFTwwAMPpHHjxhk0aFBNlwIw3wltAQrmpptuStOmTfPcc8/liCOOyCGHHJJdd901G2ywQYYPH56tttoqe++9dyZPnpxkRtjasmXL3HHHHfnPf/6TU045JSeccEJuv/328mOee+65GTRoUG644Yb8+9//zvjx43PPPfdUOO+AAQNy8803Z+DAgXn99ddz9NFHZ6+99so///nP+Xn5AACQwYMHp1evXhk0aFD23HPPmi4HYL4rKSsrK6vpIgAWZvvuu2+++eab3HPPPenSpUumTZuWf/3rX0lmzKJt3Lhxdtppp9x8881JktGjR2eZZZbJ008/nfXXX7/KYx5++OEZPXp0+UzZpZdeOsccc0yOOeaY8uO2bds2a6+9du65555MmTIlSyyxRP7xj3+kc+fO5cc54IADMnny5AwePHhevgQAAJAuXbpkrbXWykorrZQTTzwx9957bzbddNOaLgugRtSp6QIAqGjNNdcs/+/atWtnySWXzBprrFE+1rx58yTJ559/Xj52xRVX5Prrr8+oUaPy7bffZurUqVlrrbWSJOPGjcuYMWOy3nrrVThuhw4dMn369CTJu+++m8mTJ2fLLbesUMvUqVOz9tprV/s1AgBAVe688858/vnn+fe//52OHTvWdDkANUZoC1AwdevWrfBzSUlJhbGSkpIkKQ9chwwZkmOOOSYXXHBBOnfunIYNG+b888/Ps88+O8fnnDhxYpIZ64Ytu+yyFZ4rLS39WdcBAABza+21187w4cNz/fXXZ9111y3vfQEWNkJbgAXcv//972ywwQY59NBDy8fee++98v9u3Lhxmjdvnueffz6bbLJJkhnLIwwfPrx8Nu5qq62W0tLSjBo1ylfQAACoMSussEIuuOCCdOnSJbVr187ll19e0yUB1AihLcACbqWVVsrNN9+cRx55JG3atMlf/vKXPP/882nTpk35NkcccUQGDBiQFVdcMe3atctll12Wr7/+unzmQsOGDXPMMcfk6KOPzvTp07PRRhtl3Lhx+fe//51GjRpln332qanLAwBgIbPyyivn8ccfT5cuXVKnTp1cfPHFNV0SwHwntAVYwB188MF56aWXsttuu6WkpCS9evXKoYcemoceeqh8m2OPPTajR49O7969U7t27Rx00EHp1q1bateuXb7NGWeckaWWWioDBgzIyJEjs/jii2edddbJCSecUBOXBQDAQmyVVVbJY489Vj7j9oILLqjpkgDmq5KysrKymi4CgPlr+vTpWXXVVdOzZ8+cccYZNV0OAAAA8BNm2gIsBD788MP8/e9/z6abbpopU6bk8ssvz/vvv5899tijpksDAAAA/ketmi4AgHmvVq1aufHGG9OxY8dsuOGGefXVV/OPf/wjq666ak2XBgAAAPwPyyMAAAAAABSImbYAAAAAAAUitAUAAAAAKBChLQAAAABAgQhtAQAAAAAKRGgLAAAAAFAgQlsAAAAAgAIR2gIAAAAAFIjQFgAAAACgQP4POQhf97Ai384AAAAASUVORK5CYII=",
      "text/plain": [
       "<Figure size 1400x600 with 2 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 416 ms, sys: 76.2 ms, total: 493 ms\n",
      "Wall time: 285 ms\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "\n",
    "plot_metrics(metrics)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "476e9dac-b760-4d4a-a504-7745ca1bf9cf",
   "metadata": {},
   "source": [
    "## EfficientNet-B0"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e9b9dafb-e11a-4b2e-be73-16ab66fc3e88",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true
   },
   "source": [
    "### Training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 311,
   "id": "7738323f-3f89-4235-9eb1-fcccb001ec29",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 1.48 s, sys: 43.2 ms, total: 1.53 s\n",
      "Wall time: 1.47 s\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "\n",
    "loss = tfr.keras.losses.ListMLELoss()\n",
    "model = EfficientNetB0RankingModel(loss)\n",
    "lr = 1e-2\n",
    "optimizer = tf.keras.optimizers.Adam(learning_rate=lr, clipnorm=1.0)\n",
    "model.compile(optimizer=optimizer)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 312,
   "id": "7fba6d4f-1ee9-4635-8c90-7d544377f2af",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 3.08 ms, sys: 120 μs, total: 3.2 ms\n",
      "Wall time: 1.51 ms\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "\n",
    "batch = train_dataset.take(1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 313,
   "id": "6ac234c1-e6aa-4056-abc1-857b3d678355",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/1 [==============================] - 9s 9s/step\n",
      "CPU times: user 8.64 s, sys: 333 ms, total: 8.98 s\n",
      "Wall time: 8.7 s\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2024-06-14 14:30:17.297735: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 12509860583265284084\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "array([[0.19675249, 0.19490921, 0.19572935, 0.19560882, 0.19503164],\n",
       "       [0.1955759 , 0.19603413, 0.19393641, 0.19602343, 0.19524834],\n",
       "       [0.19537053, 0.19585878, 0.19555992, 0.19509128, 0.19562998],\n",
       "       [0.19440651, 0.19441356, 0.19449607, 0.19420327, 0.19494164],\n",
       "       [0.20711109, 0.19930899, 0.2000733 , 0.19797668, 0.20055337],\n",
       "       [0.19434012, 0.19334951, 0.19733894, 0.19557133, 0.19478962],\n",
       "       [0.1974859 , 0.19503456, 0.19680957, 0.19710691, 0.19512567],\n",
       "       [0.19825208, 0.19604322, 0.20421478, 0.19697529, 0.20044214],\n",
       "       [0.19752827, 0.19573367, 0.19663769, 0.19615535, 0.19625306],\n",
       "       [0.1926674 , 0.19384837, 0.1944516 , 0.19996169, 0.19918634],\n",
       "       [0.19450673, 0.19485623, 0.19512795, 0.19582999, 0.19599776],\n",
       "       [0.19804987, 0.1930091 , 0.20236115, 0.19361277, 0.19886932],\n",
       "       [0.19510499, 0.19546753, 0.19626635, 0.19479534, 0.19463289],\n",
       "       [0.19534433, 0.19558035, 0.19504504, 0.19445919, 0.19518499],\n",
       "       [0.19529423, 0.19490598, 0.1956499 , 0.19436604, 0.19410336],\n",
       "       [0.19423246, 0.19414455, 0.19413944, 0.1959129 , 0.19399936],\n",
       "       [0.20276746, 0.1947943 , 0.19524454, 0.19718136, 0.19601628],\n",
       "       [0.19428371, 0.19457404, 0.19430164, 0.19455174, 0.19501469],\n",
       "       [0.20016432, 0.19812328, 0.19564703, 0.19652773, 0.19730094],\n",
       "       [0.19489513, 0.1954475 , 0.19714893, 0.19467078, 0.19469571],\n",
       "       [0.20189254, 0.2010153 , 0.19540754, 0.19722357, 0.1960279 ],\n",
       "       [0.19538566, 0.19557446, 0.19557625, 0.19670072, 0.19565174],\n",
       "       [0.19684406, 0.19381982, 0.19581792, 0.19638415, 0.19712824],\n",
       "       [0.19518447, 0.19489694, 0.1940242 , 0.19559458, 0.1957088 ],\n",
       "       [0.1952575 , 0.19457419, 0.19552428, 0.19486888, 0.19486386],\n",
       "       [0.19404285, 0.1933791 , 0.19523853, 0.19387278, 0.19435285],\n",
       "       [0.1946713 , 0.19470872, 0.19476809, 0.19571117, 0.19480252],\n",
       "       [0.19457832, 0.19368702, 0.19428855, 0.19407544, 0.19395366],\n",
       "       [0.19339943, 0.19667053, 0.19120891, 0.19030322, 0.19930366],\n",
       "       [0.19600679, 0.19707496, 0.19561774, 0.19691493, 0.19628417]],\n",
       "      dtype=float32)"
      ]
     },
     "execution_count": 313,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "%%time\n",
    "\n",
    "model.predict(batch)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 314,
   "id": "106b4230-7352-421c-9a2b-6279f856633b",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model: \"efficient_net_b0_ranking_model_3\"\n",
      "_________________________________________________________________\n",
      " Layer (type)                Output Shape              Param #   \n",
      "=================================================================\n",
      " efficientnetb0 (Functional  (None, 7, 7, 1280)        4049571   \n",
      " )                                                               \n",
      "                                                                 \n",
      " flatten_16 (Flatten)        multiple                  0         \n",
      "                                                                 \n",
      " sequential_32 (Sequential)  (None, 64)                32285632  \n",
      "                                                                 \n",
      " sequential_33 (Sequential)  (None, 1)                 65        \n",
      "                                                                 \n",
      " ranking_16 (Ranking)        multiple                  0 (unused)\n",
      "                                                                 \n",
      "=================================================================\n",
      "Total params: 36335268 (138.61 MB)\n",
      "Trainable params: 32285697 (123.16 MB)\n",
      "Non-trainable params: 4049571 (15.45 MB)\n",
      "_________________________________________________________________\n",
      "CPU times: user 24.6 ms, sys: 3.87 ms, total: 28.5 ms\n",
      "Wall time: 26.2 ms\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "\n",
    "model.summary()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 315,
   "id": "1c915600-024b-4fb2-bc53-43e750fc094e",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1/10\n",
      "    125/Unknown - 431s 3s/step - ndcg_metric: 0.7772 - mrr_metric: 0.9435 - opa_metric: 0.6308 - loss: 4.5680 - regularization_loss: 0.0000e+00 - total_loss: 4.5680"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2024-06-14 14:37:31.519009: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 17833060689356480733\n",
      "2024-06-14 14:37:31.519054: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 2748196339091374153\n",
      "2024-06-14 14:37:31.519065: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 17424786731806597647\n",
      "2024-06-14 14:37:31.519075: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 13535894772148257793\n",
      "2024-06-14 14:37:31.519083: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 4332546701843564235\n",
      "2024-06-14 14:37:31.519090: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 12110303264717908113\n",
      "2024-06-14 14:37:31.519098: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 994988788210125525\n",
      "2024-06-14 14:37:31.519105: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv i"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "125/125 [==============================] - 434s 3s/step - ndcg_metric: 0.7772 - mrr_metric: 0.9435 - opa_metric: 0.6308 - loss: 4.5688 - regularization_loss: 0.0000e+00 - total_loss: 4.5688\n",
      "Epoch 2/10\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "tem cancelled. Key hash: 17483172285531818939\n",
      "2024-06-14 14:37:31.519113: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 17207925778548190583\n",
      "2024-06-14 14:37:31.519120: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 3491396103798238519\n",
      "2024-06-14 14:37:31.519127: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 13989459128702177095\n",
      "2024-06-14 14:37:31.519135: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 2067273449582556057\n",
      "2024-06-14 14:37:31.519143: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 2434843169122817963\n",
      "2024-06-14 14:37:31.519150: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 15568821152410343471\n",
      "2024-06-14 14:37:31.519158: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 56807832699583663\n",
      "2024-06-14 14:37:31.519166: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 16040961953823497321\n",
      "2024-06-14 14:37:31.519179: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 17865532830337555086\n",
      "2024-06-14 14:37:31.519185: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 16912052563122989324\n",
      "2024-06-14 14:37:31.519193: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 4584459024023259878\n",
      "2024-06-14 14:37:31.519200: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 15848655563306969916\n",
      "2024-06-14 14:37:31.519207: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 15744476732560543222\n",
      "2024-06-14 14:37:31.519214: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 10402076770961051762\n",
      "2024-06-14 14:37:31.519236: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 14768520104370950942\n",
      "2024-06-14 14:37:31.519245: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 9198441697121596488\n",
      "2024-06-14 14:37:31.519253: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 6678319962389468598\n",
      "2024-06-14 14:37:31.519260: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 3626978524075259822\n",
      "2024-06-14 14:37:31.519268: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 12644408257046894556\n",
      "2024-06-14 14:37:31.519276: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 8177686799467870780\n",
      "2024-06-14 14:37:31.519283: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 12016004559770065520\n",
      "2024-06-14 14:37:31.519292: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 870809600261651646\n",
      "2024-06-14 14:37:31.519361: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 14730708846889098686\n",
      "2024-06-14 14:37:31.519373: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 5771693987001687768\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "125/125 [==============================] - 425s 3s/step - ndcg_metric: 0.7970 - mrr_metric: 0.9493 - opa_metric: 0.6634 - loss: 4.4275 - regularization_loss: 0.0000e+00 - total_loss: 4.4275\n",
      "Epoch 3/10\n",
      "125/125 [==============================] - 425s 3s/step - ndcg_metric: 0.8022 - mrr_metric: 0.9448 - opa_metric: 0.6667 - loss: 4.4264 - regularization_loss: 0.0000e+00 - total_loss: 4.4264\n",
      "Epoch 4/10\n",
      "125/125 [==============================] - 425s 3s/step - ndcg_metric: 0.8109 - mrr_metric: 0.9529 - opa_metric: 0.6814 - loss: 4.3845 - regularization_loss: 0.0000e+00 - total_loss: 4.3845\n",
      "Epoch 5/10\n",
      "125/125 [==============================] - 425s 3s/step - ndcg_metric: 0.8167 - mrr_metric: 0.9588 - opa_metric: 0.6906 - loss: 4.3169 - regularization_loss: 0.0000e+00 - total_loss: 4.3169\n",
      "Epoch 6/10\n",
      "125/125 [==============================] - 424s 3s/step - ndcg_metric: 0.8195 - mrr_metric: 0.9599 - opa_metric: 0.6935 - loss: 4.2742 - regularization_loss: 0.0000e+00 - total_loss: 4.2742\n",
      "Epoch 7/10\n",
      "125/125 [==============================] - 424s 3s/step - ndcg_metric: 0.8224 - mrr_metric: 0.9592 - opa_metric: 0.6977 - loss: 4.2521 - regularization_loss: 0.0000e+00 - total_loss: 4.2521\n",
      "Epoch 8/10\n",
      "125/125 [==============================] - 426s 3s/step - ndcg_metric: 0.8230 - mrr_metric: 0.9605 - opa_metric: 0.7000 - loss: 4.2548 - regularization_loss: 0.0000e+00 - total_loss: 4.2548\n",
      "Epoch 9/10\n",
      "125/125 [==============================] - 427s 3s/step - ndcg_metric: 0.8236 - mrr_metric: 0.9612 - opa_metric: 0.7009 - loss: 4.2242 - regularization_loss: 0.0000e+00 - total_loss: 4.2242\n",
      "Epoch 10/10\n",
      "125/125 [==============================] - 426s 3s/step - ndcg_metric: 0.8229 - mrr_metric: 0.9627 - opa_metric: 0.6986 - loss: 4.2349 - regularization_loss: 0.0000e+00 - total_loss: 4.2349\n",
      "CPU times: user 1h 10min 38s, sys: 2min 30s, total: 1h 13min 9s\n",
      "Wall time: 1h 11min 2s\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "\n",
    "history = model.fit(train_dataset, epochs=EPOCHS)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 316,
   "id": "7fe1e93a-f84d-4292-a3ff-1c6c01fbb3a6",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 15 μs, sys: 0 ns, total: 15 μs\n",
      "Wall time: 29.3 μs\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "\n",
    "show_on_test_batch = True"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 317,
   "id": "0bb2583c-3007-461d-a5c0-305a67bacfda",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<_TakeDataset element_spec=(TensorSpec(shape=(None, 5, 224, 224, 3), dtype=tf.float32, name=None), TensorSpec(shape=(None, 5), dtype=tf.int32, name=None))>\n",
      "tf.Tensor([830.7962  830.01263 830.3658  828.1153  829.58905], shape=(5,), dtype=float32) tf.Tensor([4 1 3 0 2], shape=(5,), dtype=int32) tf.Tensor([4 2 3 0 1], shape=(5,), dtype=int32)\n",
      "tf.Tensor([829.77356 830.00793 829.58545 829.7817  828.32465], shape=(5,), dtype=float32) tf.Tensor([4 3 1 2 0], shape=(5,), dtype=int32) tf.Tensor([2 4 1 3 0], shape=(5,), dtype=int32)\n",
      "tf.Tensor([828.0013 829.6252 829.7382 829.948  829.7444], shape=(5,), dtype=float32) tf.Tensor([1 2 4 3 0], shape=(5,), dtype=int32) tf.Tensor([0 1 2 4 3], shape=(5,), dtype=int32)\n",
      "CPU times: user 4.81 s, sys: 140 ms, total: 4.95 s\n",
      "Wall time: 4.11 s\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "\n",
    "if show_on_test_batch:\n",
    "    batch = train_dataset.take(1)\n",
    "    print(batch)\n",
    "    for images, labels in batch:\n",
    "        print(model(images)[0], labels[0], predict_rank(model, images)[0])\n",
    "        print(model(images)[1], labels[1], predict_rank(model, images)[1])\n",
    "        print(model(images)[2], labels[2], predict_rank(model, images)[2])\n",
    "        break"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "abbc2e0f-6075-4b53-9394-43fbd517fb27",
   "metadata": {},
   "outputs": [],
   "source": [
    "%%time\n",
    "\n",
    "test_metrics = model.evaluate(test_images, test_labels, return_dict=True)\n",
    "print(f\"Test metrics: {test_metrics}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3f9e8ee0-4b26-46a4-ad8e-913602264e15",
   "metadata": {},
   "outputs": [],
   "source": [
    "%%time\n",
    "\n",
    "val_metrics = model.evaluate(val_images, val_labels, return_dict=True)\n",
    "print(f\"Val metrics: {test_metrics}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 318,
   "id": "a92bad64-ee7a-4854-913c-a5e91688ee09",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(62720, 512), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c28e373a0>, 139758823336064), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(62720, 512), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c28e373a0>, 139758823336064), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(512,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c28e36a70>, 139758823336704), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(512,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c28e36a70>, 139758823336704), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(512, 256), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c4c262a10>, 139758823332224), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(512, 256), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c4c262a10>, 139758823332224), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(256,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1dac3ea9b0>, 139758823334304), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(256,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1dac3ea9b0>, 139758823334304), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(256, 128), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c4c28acb0>, 139760922666736), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(256, 128), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c4c28acb0>, 139760922666736), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(128,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1dac43dc00>, 139760922665136), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(128,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1dac43dc00>, 139760922665136), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(128, 64), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c4c2d80a0>, 139758823485120), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(128, 64), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c4c2d80a0>, 139758823485120), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(64,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c6b5b21d0>, 139758823482320), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(64,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c6b5b21d0>, 139758823482320), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(64, 1), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c4c1002b0>, 139758823333504), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(64, 1), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c4c1002b0>, 139758823333504), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(1,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1dac461fc0>, 139758823332784), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(1,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1dac461fc0>, 139758823332784), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(62720, 512), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c28e373a0>, 139758823336064), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(62720, 512), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c28e373a0>, 139758823336064), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(512,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c28e36a70>, 139758823336704), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(512,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c28e36a70>, 139758823336704), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(512, 256), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c4c262a10>, 139758823332224), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(512, 256), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c4c262a10>, 139758823332224), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(256,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1dac3ea9b0>, 139758823334304), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(256,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1dac3ea9b0>, 139758823334304), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(256, 128), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c4c28acb0>, 139760922666736), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(256, 128), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c4c28acb0>, 139760922666736), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(128,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1dac43dc00>, 139760922665136), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(128,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1dac43dc00>, 139760922665136), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(128, 64), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c4c2d80a0>, 139758823485120), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(128, 64), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c4c2d80a0>, 139758823485120), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(64,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c6b5b21d0>, 139758823482320), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(64,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c6b5b21d0>, 139758823482320), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(64, 1), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c4c1002b0>, 139758823333504), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(64, 1), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c4c1002b0>, 139758823333504), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(1,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1dac461fc0>, 139758823332784), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(1,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1dac461fc0>, 139758823332784), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Assets written to: tesla_saved_models/EfficientNetB0RankingModel_20240614_154123_freezed_0.01/assets\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Assets written to: tesla_saved_models/EfficientNetB0RankingModel_20240614_154123_freezed_0.01/assets\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model saved in tesla_saved_models/EfficientNetB0RankingModel_20240614_154123_freezed_0.01 as EfficientNetB0RankingModel_20240614_154123_freezed_0.01\n",
      "CPU times: user 16.9 s, sys: 745 ms, total: 17.6 s\n",
      "Wall time: 17.5 s\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "\n",
    "if SAVE_MODEL:\n",
    "    save_model_with_timestamp_and_lr(model, lr, freezed=\"freezed\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "39536225-bbcb-40bd-9e35-fe76ed438b35",
   "metadata": {},
   "source": [
    "### Evaluating"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 386,
   "id": "78e1789d-ade7-495a-9872-46f92944c40e",
   "metadata": {},
   "outputs": [],
   "source": [
    "loaded_model = tf.keras.models.load_model('tesla_saved_models/EfficientNetB0RankingModel_20240614_120318_freezed_0.001', \n",
    "                                          custom_objects={'ListMLELoss': tfr.keras.losses.ListMLELoss()})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 387,
   "id": "390bbc18-10f6-4df1-8855-ad63a5b80b51",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing batches: 0it [00:00, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/1 [==============================] - 1s 1s/step\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing batches: 1it [00:04,  4.59s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/1 [==============================] - 0s 97ms/step\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing batches: 2it [00:10,  5.26s/it]\n"
     ]
    },
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
      "File \u001b[0;32m<timed exec>:1\u001b[0m\n",
      "File \u001b[0;32m~/diploma/utils.py:203\u001b[0m, in \u001b[0;36mcompute_dataset_metrics\u001b[0;34m(model, dataset, max_k)\u001b[0m\n\u001b[1;32m    192\u001b[0m aggregated_metrics \u001b[38;5;241m=\u001b[39m {\n\u001b[1;32m    193\u001b[0m     \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mNDCG\u001b[39m\u001b[38;5;124m'\u001b[39m: \u001b[38;5;241m0.0\u001b[39m,\n\u001b[1;32m    194\u001b[0m     \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mpermutation_accuracy\u001b[39m\u001b[38;5;124m'\u001b[39m: \u001b[38;5;241m0.0\u001b[39m,\n\u001b[0;32m   (...)\u001b[0m\n\u001b[1;32m    199\u001b[0m     \u001b[38;5;124m'\u001b[39m\u001b[38;5;124maccuracy@k\u001b[39m\u001b[38;5;124m'\u001b[39m: np\u001b[38;5;241m.\u001b[39mzeros(max_k)\n\u001b[1;32m    200\u001b[0m }\n\u001b[1;32m    201\u001b[0m total_batches \u001b[38;5;241m=\u001b[39m \u001b[38;5;241m0\u001b[39m\n\u001b[0;32m--> 203\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m images, labels \u001b[38;5;129;01min\u001b[39;00m tqdm(dataset, desc\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mProcessing batches\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n\u001b[1;32m    204\u001b[0m     scores \u001b[38;5;241m=\u001b[39m model\u001b[38;5;241m.\u001b[39mpredict(images)\n\u001b[1;32m    205\u001b[0m     ranks \u001b[38;5;241m=\u001b[39m np\u001b[38;5;241m.\u001b[39margsort(np\u001b[38;5;241m.\u001b[39margsort(scores, axis\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m1\u001b[39m), axis\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m1\u001b[39m)\n",
      "File \u001b[0;32m~/miniconda3/envs/diploma/lib/python3.10/site-packages/tqdm/std.py:1181\u001b[0m, in \u001b[0;36mtqdm.__iter__\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m   1178\u001b[0m time \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_time\n\u001b[1;32m   1180\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m-> 1181\u001b[0m     \u001b[38;5;28;01mfor\u001b[39;00m obj \u001b[38;5;129;01min\u001b[39;00m iterable:\n\u001b[1;32m   1182\u001b[0m         \u001b[38;5;28;01myield\u001b[39;00m obj\n\u001b[1;32m   1183\u001b[0m         \u001b[38;5;66;03m# Update and possibly print the progressbar.\u001b[39;00m\n\u001b[1;32m   1184\u001b[0m         \u001b[38;5;66;03m# Note: does not call self.update(1) for speed optimisation.\u001b[39;00m\n",
      "File \u001b[0;32m~/miniconda3/envs/diploma/lib/python3.10/site-packages/tensorflow/python/data/ops/iterator_ops.py:810\u001b[0m, in \u001b[0;36mOwnedIterator.__next__\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m    808\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m__next__\u001b[39m(\u001b[38;5;28mself\u001b[39m):\n\u001b[1;32m    809\u001b[0m   \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m--> 810\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_next_internal\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    811\u001b[0m   \u001b[38;5;28;01mexcept\u001b[39;00m errors\u001b[38;5;241m.\u001b[39mOutOfRangeError:\n\u001b[1;32m    812\u001b[0m     \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mStopIteration\u001b[39;00m\n",
      "File \u001b[0;32m~/miniconda3/envs/diploma/lib/python3.10/site-packages/tensorflow/python/data/ops/iterator_ops.py:773\u001b[0m, in \u001b[0;36mOwnedIterator._next_internal\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m    770\u001b[0m \u001b[38;5;66;03m# TODO(b/77291417): This runs in sync mode as iterators use an error status\u001b[39;00m\n\u001b[1;32m    771\u001b[0m \u001b[38;5;66;03m# to communicate that there is no more data to iterate over.\u001b[39;00m\n\u001b[1;32m    772\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m context\u001b[38;5;241m.\u001b[39mexecution_mode(context\u001b[38;5;241m.\u001b[39mSYNC):\n\u001b[0;32m--> 773\u001b[0m   ret \u001b[38;5;241m=\u001b[39m \u001b[43mgen_dataset_ops\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43miterator_get_next\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m    774\u001b[0m \u001b[43m      \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_iterator_resource\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    775\u001b[0m \u001b[43m      \u001b[49m\u001b[43moutput_types\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_flat_output_types\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    776\u001b[0m \u001b[43m      \u001b[49m\u001b[43moutput_shapes\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_flat_output_shapes\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    778\u001b[0m   \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m    779\u001b[0m     \u001b[38;5;66;03m# Fast path for the case `self._structure` is not a nested structure.\u001b[39;00m\n\u001b[1;32m    780\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_element_spec\u001b[38;5;241m.\u001b[39m_from_compatible_tensor_list(ret)  \u001b[38;5;66;03m# pylint: disable=protected-access\u001b[39;00m\n",
      "File \u001b[0;32m~/miniconda3/envs/diploma/lib/python3.10/site-packages/tensorflow/python/ops/gen_dataset_ops.py:3024\u001b[0m, in \u001b[0;36miterator_get_next\u001b[0;34m(iterator, output_types, output_shapes, name)\u001b[0m\n\u001b[1;32m   3022\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m tld\u001b[38;5;241m.\u001b[39mis_eager:\n\u001b[1;32m   3023\u001b[0m   \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m-> 3024\u001b[0m     _result \u001b[38;5;241m=\u001b[39m \u001b[43mpywrap_tfe\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mTFE_Py_FastPathExecute\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m   3025\u001b[0m \u001b[43m      \u001b[49m\u001b[43m_ctx\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mIteratorGetNext\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mname\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43miterator\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43moutput_types\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43moutput_types\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   3026\u001b[0m \u001b[43m      \u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43moutput_shapes\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43moutput_shapes\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   3027\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m _result\n\u001b[1;32m   3028\u001b[0m   \u001b[38;5;28;01mexcept\u001b[39;00m _core\u001b[38;5;241m.\u001b[39m_NotOkStatusException \u001b[38;5;28;01mas\u001b[39;00m e:\n",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
     ]
    }
   ],
   "source": [
    "%%time\n",
    "\n",
    "metrics = compute_dataset_metrics(loaded_model, train_dataset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 388,
   "id": "e9185a23-b38d-45c0-9622-7d30c00e3126",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'NDCG': 0.8922528057539286, 'permutation_accuracy': <tf.Tensor: shape=(), dtype=float32, numpy=0.37481487>, 'Kendalls Tau': 0.38419753086419745, 'Spearmans Rho': 0.45580246913580236, 'OPA': 0.6920987654320989, 'accuracy_by_rank': array([0.51728396, 0.36666667, 0.30617285, 0.28765433, 0.3962963 ]), 'accuracy@k': array([       nan, 0.3962963 , 0.60555556, 0.76213992, 0.87932099])}\n",
      "CPU times: user 648 μs, sys: 25 μs, total: 673 μs\n",
      "Wall time: 566 μs\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "\n",
    "print(metrics)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 389,
   "id": "5a878f4c-aa9c-4da3-bece-6466ca018db7",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing batches: 0it [00:00, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/1 [==============================] - 0s 97ms/step\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing batches: 1it [00:00,  3.90it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/1 [==============================] - 0s 100ms/step\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing batches: 2it [00:00,  3.88it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/1 [==============================] - 0s 93ms/step\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing batches: 3it [00:00,  3.90it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/1 [==============================] - 0s 89ms/step\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing batches: 4it [00:01,  3.94it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/1 [==============================] - 0s 103ms/step\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing batches: 5it [00:01,  3.85it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/1 [==============================] - 0s 102ms/step\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing batches: 6it [00:01,  3.76it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/1 [==============================] - 0s 94ms/step\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing batches: 7it [00:01,  3.80it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/1 [==============================] - 0s 100ms/step\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing batches: 8it [00:02,  3.83it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/1 [==============================] - 0s 100ms/step\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing batches: 9it [00:02,  3.86it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/1 [==============================] - 0s 97ms/step\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing batches: 10it [00:02,  3.84it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/1 [==============================] - 0s 100ms/step\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing batches: 11it [00:02,  3.86it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/1 [==============================] - 0s 102ms/step\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing batches: 12it [00:03,  3.86it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/1 [==============================] - 0s 100ms/step\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing batches: 13it [00:03,  3.87it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/1 [==============================] - 0s 100ms/step\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing batches: 14it [00:03,  3.88it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/1 [==============================] - 0s 97ms/step\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing batches: 15it [00:03,  3.86it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/1 [==============================] - 0s 102ms/step\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing batches: 16it [00:04,  3.71it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/1 [==============================] - 0s 101ms/step\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing batches: 17it [00:04,  3.76it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/1 [==============================] - 0s 102ms/step\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing batches: 18it [00:04,  3.79it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/1 [==============================] - 0s 101ms/step\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing batches: 19it [00:04,  3.80it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/1 [==============================] - 0s 101ms/step\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing batches: 19it [00:05,  3.65it/s]\n",
      "\n",
      "KeyboardInterrupt\n",
      "\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "\n",
    "metrics = compute_dataset_metrics(loaded_model, val_dataset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 390,
   "id": "c49b0c80-b63b-486b-ae1d-2e4192a02cb5",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'NDCG': 0.8922528057539286, 'permutation_accuracy': <tf.Tensor: shape=(), dtype=float32, numpy=0.37481487>, 'Kendalls Tau': 0.38419753086419745, 'Spearmans Rho': 0.45580246913580236, 'OPA': 0.6920987654320989, 'accuracy_by_rank': array([0.51728396, 0.36666667, 0.30617285, 0.28765433, 0.3962963 ]), 'accuracy@k': array([       nan, 0.3962963 , 0.60555556, 0.76213992, 0.87932099])}\n",
      "CPU times: user 431 μs, sys: 17 μs, total: 448 μs\n",
      "Wall time: 441 μs\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "\n",
    "print(metrics)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 391,
   "id": "c6e236ff-22f1-4763-a7d7-1817d535ab40",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing batches: 0it [00:00, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/1 [==============================] - 0s 102ms/step\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing batches: 0it [00:00, ?it/s]\n",
      "\n",
      "KeyboardInterrupt\n",
      "\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "\n",
    "metrics = compute_dataset_metrics(loaded_model, test_dataset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 392,
   "id": "369024a3-a3b4-41f0-b8a3-409c694a275d",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'NDCG': 0.8922528057539286, 'permutation_accuracy': <tf.Tensor: shape=(), dtype=float32, numpy=0.37481487>, 'Kendalls Tau': 0.38419753086419745, 'Spearmans Rho': 0.45580246913580236, 'OPA': 0.6920987654320989, 'accuracy_by_rank': array([0.51728396, 0.36666667, 0.30617285, 0.28765433, 0.3962963 ]), 'accuracy@k': array([       nan, 0.3962963 , 0.60555556, 0.76213992, 0.87932099])}\n",
      "CPU times: user 408 μs, sys: 16 μs, total: 424 μs\n",
      "Wall time: 421 μs\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "\n",
    "print(metrics)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 393,
   "id": "54c1c335-129a-4363-ae74-e778a3d80d1e",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 7 μs, sys: 0 ns, total: 7 μs\n",
      "Wall time: 14.1 μs\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "\n",
    "show_on_test_batch = True"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 394,
   "id": "ef8b8217-30ba-4e42-abb3-5099e18d66af",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<_TakeDataset element_spec=(TensorSpec(shape=(None, 5, 224, 224, 3), dtype=tf.float32, name=None), TensorSpec(shape=(None, 5), dtype=tf.int32, name=None))>\n"
     ]
    },
    {
     "ename": "TypeError",
     "evalue": "Exception encountered when calling layer 'block6d_se_reshape' (type Reshape).\n\nExpected `context` argument in EagerTensor constructor to have a `_handle` attribute but it did not. Was eager Context initialized?\n\nCall arguments received by layer 'block6d_se_reshape' (type Reshape):\n  • inputs=tf.Tensor(shape=(150, 1152), dtype=float32)",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mTypeError\u001b[0m                                 Traceback (most recent call last)",
      "File \u001b[0;32m<timed exec>:5\u001b[0m\n",
      "File \u001b[0;32m~/miniconda3/envs/diploma/lib/python3.10/site-packages/keras/src/utils/traceback_utils.py:70\u001b[0m, in \u001b[0;36mfilter_traceback.<locals>.error_handler\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m     67\u001b[0m     filtered_tb \u001b[38;5;241m=\u001b[39m _process_traceback_frames(e\u001b[38;5;241m.\u001b[39m__traceback__)\n\u001b[1;32m     68\u001b[0m     \u001b[38;5;66;03m# To get the full stack trace, call:\u001b[39;00m\n\u001b[1;32m     69\u001b[0m     \u001b[38;5;66;03m# `tf.debugging.disable_traceback_filtering()`\u001b[39;00m\n\u001b[0;32m---> 70\u001b[0m     \u001b[38;5;28;01mraise\u001b[39;00m e\u001b[38;5;241m.\u001b[39mwith_traceback(filtered_tb) \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[1;32m     71\u001b[0m \u001b[38;5;28;01mfinally\u001b[39;00m:\n\u001b[1;32m     72\u001b[0m     \u001b[38;5;28;01mdel\u001b[39;00m filtered_tb\n",
      "File \u001b[0;32m~/diploma/models/efficientnetb0.py:71\u001b[0m, in \u001b[0;36mEfficientNetB0RankingModel.call\u001b[0;34m(self, inputs)\u001b[0m\n\u001b[1;32m     69\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mcall\u001b[39m(\u001b[38;5;28mself\u001b[39m, inputs):\n\u001b[1;32m     70\u001b[0m     flatten_inputs \u001b[38;5;241m=\u001b[39m tf\u001b[38;5;241m.\u001b[39mreshape(inputs, [\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m, IMAGENET_INPUT_H, IMAGENET_INPUT_W, IMAGENET_INPUT_D])\n\u001b[0;32m---> 71\u001b[0m     embeddings \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbase_model\u001b[49m\u001b[43m(\u001b[49m\u001b[43mflatten_inputs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m     72\u001b[0m     flattened_embeddings \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mflatten(embeddings)\n\u001b[1;32m     73\u001b[0m     dense_output \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdense_layers(flattened_embeddings)\n",
      "\u001b[0;31mTypeError\u001b[0m: Exception encountered when calling layer 'block6d_se_reshape' (type Reshape).\n\nExpected `context` argument in EagerTensor constructor to have a `_handle` attribute but it did not. Was eager Context initialized?\n\nCall arguments received by layer 'block6d_se_reshape' (type Reshape):\n  • inputs=tf.Tensor(shape=(150, 1152), dtype=float32)"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "\n",
    "if show_on_test_batch:\n",
    "    batch = test_dataset.take(1)\n",
    "    print(batch)\n",
    "    for images, labels in batch:\n",
    "        print(model(images)[0], labels[0], predict_rank(model, images)[0])\n",
    "        print(model(images)[1], labels[1], predict_rank(model, images)[1])\n",
    "        print(model(images)[2], labels[2], predict_rank(model, images)[2])\n",
    "        break"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 395,
   "id": "a3d688c4-3a9c-4c3d-bc8e-8c31e2f54f21",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n",
      "KeyboardInterrupt\n",
      "\n",
      "Exception ignored in: <bound method IPythonKernel._clean_thread_parent_frames of <ipykernel.ipkernel.IPythonKernel object at 0x7f2037414790>>\n",
      "Traceback (most recent call last):\n",
      "  File \"/home/user/miniconda3/envs/diploma/lib/python3.10/site-packages/ipykernel/ipkernel.py\", line 775, in _clean_thread_parent_frames\n",
      "    def _clean_thread_parent_frames(\n",
      "KeyboardInterrupt: \n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Error in callback <function flush_figures at 0x7f1d99b52d40> (for post_execute), with arguments args (),kwargs {}:\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n",
      "KeyboardInterrupt\n",
      "\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "\n",
    "plot_metrics(metrics)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "53cd06eb-c7f8-467a-9414-2c47900fce4f",
   "metadata": {},
   "source": [
    "## EfficientNet-B1"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d0fc421b-1e24-48eb-b1e1-387f28bad2e7",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true
   },
   "source": [
    "### Training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 295,
   "id": "4cbb3349-ae03-4d3a-9fd7-2015205bebd0",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 2.26 s, sys: 55.7 ms, total: 2.32 s\n",
      "Wall time: 2.25 s\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "\n",
    "loss = tfr.keras.losses.ListMLELoss()\n",
    "model = EfficientNetB1RankingModel(loss)\n",
    "lr = 1e-2\n",
    "optimizer = tf.keras.optimizers.Adam(learning_rate=lr, clipnorm=1.0)\n",
    "model.compile(optimizer=optimizer)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 296,
   "id": "3310a575-c6f5-427e-83a6-3d0d83c78a36",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 2.6 ms, sys: 102 μs, total: 2.7 ms\n",
      "Wall time: 1.18 ms\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "\n",
    "batch = train_dataset.take(1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 297,
   "id": "e9522279-3ff3-4cf2-babe-9b31de2684da",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/1 [==============================] - 10s 10s/step\n",
      "CPU times: user 9.61 s, sys: 427 ms, total: 10 s\n",
      "Wall time: 9.73 s\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2024-06-14 12:03:52.971960: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 12236039780566329570\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "array([[-0.44861716, -0.4473724 , -0.44786367, -0.4467623 , -0.4502459 ],\n",
       "       [-0.44739646, -0.44736022, -0.4467301 , -0.4441615 , -0.44753215],\n",
       "       [-0.44825664, -0.44792867, -0.44759896, -0.44845134, -0.44888973],\n",
       "       [-0.44848698, -0.44674063, -0.44719857, -0.4480282 , -0.44518495],\n",
       "       [-0.45421326, -0.4491125 , -0.45053026, -0.46408835, -0.44994348],\n",
       "       [-0.44703037, -0.4472928 , -0.44871193, -0.45124754, -0.44981855],\n",
       "       [-0.4467917 , -0.44730407, -0.45021456, -0.4476329 , -0.44759285],\n",
       "       [-0.46015668, -0.44973052, -0.4506296 , -0.43629724, -0.44845125],\n",
       "       [-0.44913438, -0.44935626, -0.4484954 , -0.4489308 , -0.44882777],\n",
       "       [-0.44932956, -0.4448642 , -0.4416835 , -0.44249398, -0.4479276 ],\n",
       "       [-0.4474766 , -0.44740832, -0.44705287, -0.44761962, -0.44683856],\n",
       "       [-0.44742107, -0.45570678, -0.44409263, -0.45542106, -0.45026657],\n",
       "       [-0.44820383, -0.4488032 , -0.4507006 , -0.44865105, -0.4491079 ],\n",
       "       [-0.44941717, -0.4475036 , -0.44613022, -0.45000333, -0.44787523],\n",
       "       [-0.44884992, -0.44867855, -0.44861737, -0.4496535 , -0.44729117],\n",
       "       [-0.44672942, -0.44671923, -0.44513997, -0.4479093 , -0.44485492],\n",
       "       [-0.44570547, -0.44683766, -0.44709238, -0.44557294, -0.44650373],\n",
       "       [-0.44722325, -0.44814938, -0.44741172, -0.44790113, -0.4476341 ],\n",
       "       [-0.45374486, -0.44842348, -0.45149127, -0.45005155, -0.47161967],\n",
       "       [-0.4483904 , -0.44742694, -0.44920534, -0.44762146, -0.44892898],\n",
       "       [-0.4472748 , -0.45934147, -0.4485613 , -0.4595956 , -0.45967183],\n",
       "       [-0.4441376 , -0.4479509 , -0.44825426, -0.4469994 , -0.44712678],\n",
       "       [-0.44682136, -0.4492783 , -0.4483858 , -0.44975832, -0.44528815],\n",
       "       [-0.4430675 , -0.44700465, -0.4461583 , -0.44700038, -0.44610718],\n",
       "       [-0.44937137, -0.448581  , -0.44595888, -0.44973734, -0.44920528],\n",
       "       [-0.44738117, -0.4481543 , -0.44741863, -0.44760785, -0.4474491 ],\n",
       "       [-0.4472811 , -0.44801268, -0.4478158 , -0.4469968 , -0.44747952],\n",
       "       [-0.44759595, -0.44788074, -0.4482642 , -0.44830835, -0.44913086],\n",
       "       [-0.44274345, -0.44075543, -0.44620025, -0.44131866, -0.44914657],\n",
       "       [-0.4449475 , -0.44907346, -0.44939277, -0.44739893, -0.44609827]],\n",
       "      dtype=float32)"
      ]
     },
     "execution_count": 297,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "%%time\n",
    "\n",
    "model.predict(batch)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 298,
   "id": "f7c96308-daee-4bb3-8aaa-5486f3aa8d1b",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model: \"efficient_net_b1_ranking_model\"\n",
      "_________________________________________________________________\n",
      " Layer (type)                Output Shape              Param #   \n",
      "=================================================================\n",
      " efficientnetb1 (Functional  (None, 7, 7, 1280)        6575239   \n",
      " )                                                               \n",
      "                                                                 \n",
      " flatten_14 (Flatten)        multiple                  0         \n",
      "                                                                 \n",
      " sequential_28 (Sequential)  (None, 64)                32285632  \n",
      "                                                                 \n",
      " sequential_29 (Sequential)  (None, 1)                 65        \n",
      "                                                                 \n",
      " ranking_14 (Ranking)        multiple                  0 (unused)\n",
      "                                                                 \n",
      "=================================================================\n",
      "Total params: 38860936 (148.24 MB)\n",
      "Trainable params: 32285697 (123.16 MB)\n",
      "Non-trainable params: 6575239 (25.08 MB)\n",
      "_________________________________________________________________\n",
      "CPU times: user 35 ms, sys: 0 ns, total: 35 ms\n",
      "Wall time: 32.1 ms\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "\n",
    "model.summary()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 299,
   "id": "1bc5088c-73fa-49fc-9ec8-f9acd5dfba18",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1/10\n",
      "    125/Unknown - 447s 4s/step - ndcg_metric: 0.7581 - mrr_metric: 0.9225 - opa_metric: 0.6006 - loss: 4.8588 - regularization_loss: 0.0000e+00 - total_loss: 4.8588"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2024-06-14 12:11:23.496617: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 17185729678796480377\n",
      "2024-06-14 12:11:23.496662: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 17934231380188029487\n",
      "2024-06-14 12:11:23.496673: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 12545845259178288595\n",
      "2024-06-14 12:11:23.496680: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 10048011897943352617\n",
      "2024-06-14 12:11:23.496689: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 10518070551724558455\n",
      "2024-06-14 12:11:23.496696: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 1507002403021571385\n",
      "2024-06-14 12:11:23.496704: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 17157022900605095531\n",
      "2024-06-14 12:11:23.496711: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous rec"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "125/125 [==============================] - 451s 4s/step - ndcg_metric: 0.7581 - mrr_metric: 0.9225 - opa_metric: 0.6006 - loss: 4.8552 - regularization_loss: 0.0000e+00 - total_loss: 4.8552\n",
      "Epoch 2/10\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "v item cancelled. Key hash: 16802513923252580631\n",
      "2024-06-14 12:11:23.496719: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 6531947471702364895\n",
      "2024-06-14 12:11:23.496726: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 4136865420836369309\n",
      "2024-06-14 12:11:23.496733: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 16630275330140082463\n",
      "2024-06-14 12:11:23.496740: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 9870101696653423895\n",
      "2024-06-14 12:11:23.496748: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 3584354210964846163\n",
      "2024-06-14 12:11:23.496756: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 17592151053366817911\n",
      "2024-06-14 12:11:23.496764: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 16473717737596885931\n",
      "2024-06-14 12:11:23.496771: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 9206347368613699441\n",
      "2024-06-14 12:11:23.496785: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 5209917563158725722\n",
      "2024-06-14 12:11:23.496792: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 17353611419382295224\n",
      "2024-06-14 12:11:23.496799: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 6928336483501075546\n",
      "2024-06-14 12:11:23.496806: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 6836649655122881836\n",
      "2024-06-14 12:11:23.496813: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 14450935776520719870\n",
      "2024-06-14 12:11:23.496820: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 11987253178149031690\n",
      "2024-06-14 12:11:23.496828: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 16653572390649098266\n",
      "2024-06-14 12:11:23.496835: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 17088557599765868742\n",
      "2024-06-14 12:11:23.496842: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 13450196827437289804\n",
      "2024-06-14 12:11:23.496849: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 7293075784703751104\n",
      "2024-06-14 12:11:23.496856: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 12374655626325021514\n",
      "2024-06-14 12:11:23.496863: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 2532283602851149302\n",
      "2024-06-14 12:11:23.496872: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 9641376884039380400\n",
      "2024-06-14 12:11:23.496895: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 5158781655487453760\n",
      "2024-06-14 12:11:23.496996: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 17523664218585635964\n",
      "2024-06-14 12:11:23.497008: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 3116866217782255382\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "125/125 [==============================] - 440s 4s/step - ndcg_metric: 0.7967 - mrr_metric: 0.9528 - opa_metric: 0.6624 - loss: 4.4719 - regularization_loss: 0.0000e+00 - total_loss: 4.4719\n",
      "Epoch 3/10\n",
      "125/125 [==============================] - 440s 4s/step - ndcg_metric: 0.7993 - mrr_metric: 0.9487 - opa_metric: 0.6632 - loss: 4.5008 - regularization_loss: 0.0000e+00 - total_loss: 4.5008\n",
      "Epoch 4/10\n",
      "125/125 [==============================] - 440s 4s/step - ndcg_metric: 0.8046 - mrr_metric: 0.9560 - opa_metric: 0.6739 - loss: 4.3970 - regularization_loss: 0.0000e+00 - total_loss: 4.3970\n",
      "Epoch 5/10\n",
      "125/125 [==============================] - 440s 4s/step - ndcg_metric: 0.8051 - mrr_metric: 0.9545 - opa_metric: 0.6789 - loss: 4.3879 - regularization_loss: 0.0000e+00 - total_loss: 4.3879\n",
      "Epoch 6/10\n",
      "125/125 [==============================] - 440s 4s/step - ndcg_metric: 0.8105 - mrr_metric: 0.9559 - opa_metric: 0.6838 - loss: 4.3419 - regularization_loss: 0.0000e+00 - total_loss: 4.3419\n",
      "Epoch 7/10\n",
      "125/125 [==============================] - 440s 4s/step - ndcg_metric: 0.8111 - mrr_metric: 0.9559 - opa_metric: 0.6845 - loss: 4.3584 - regularization_loss: 0.0000e+00 - total_loss: 4.3584\n",
      "Epoch 8/10\n",
      "125/125 [==============================] - 441s 4s/step - ndcg_metric: 0.8113 - mrr_metric: 0.9532 - opa_metric: 0.6814 - loss: 4.3528 - regularization_loss: 0.0000e+00 - total_loss: 4.3528\n",
      "Epoch 9/10\n",
      "125/125 [==============================] - 430s 3s/step - ndcg_metric: 0.7992 - mrr_metric: 0.9423 - opa_metric: 0.6581 - loss: 4.5979 - regularization_loss: 0.0000e+00 - total_loss: 4.5979\n",
      "Epoch 10/10\n",
      "125/125 [==============================] - 430s 3s/step - ndcg_metric: 0.8147 - mrr_metric: 0.9557 - opa_metric: 0.6859 - loss: 4.3610 - regularization_loss: 0.0000e+00 - total_loss: 4.3610\n",
      "CPU times: user 1h 12min 18s, sys: 2min 38s, total: 1h 14min 56s\n",
      "Wall time: 1h 13min 13s\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "\n",
    "history = model.fit(train_dataset, epochs=EPOCHS)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 300,
   "id": "2f6863f3-6ae9-4c41-90b1-f6659f4c5627",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 14 μs, sys: 0 ns, total: 14 μs\n",
      "Wall time: 27.9 μs\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "\n",
    "show_on_test_batch = True"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 301,
   "id": "e2a911f4-d957-4342-badd-c25429ca2560",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<_TakeDataset element_spec=(TensorSpec(shape=(None, 5, 224, 224, 3), dtype=tf.float32, name=None), TensorSpec(shape=(None, 5), dtype=tf.int32, name=None))>\n",
      "tf.Tensor([929.06537 928.09735 928.8401  928.703   929.4992 ], shape=(5,), dtype=float32) tf.Tensor([1 0 2 3 4], shape=(5,), dtype=int32) tf.Tensor([3 0 2 1 4], shape=(5,), dtype=int32)\n",
      "tf.Tensor([927.8751  929.10925 929.1825  929.04584 929.2229 ], shape=(5,), dtype=float32) tf.Tensor([0 2 3 1 4], shape=(5,), dtype=int32) tf.Tensor([0 2 3 1 4], shape=(5,), dtype=int32)\n",
      "tf.Tensor([929.04236 928.0904  928.7208  928.96    928.7456 ], shape=(5,), dtype=float32) tf.Tensor([3 1 0 4 2], shape=(5,), dtype=int32) tf.Tensor([4 0 1 3 2], shape=(5,), dtype=int32)\n",
      "CPU times: user 5.61 s, sys: 177 ms, total: 5.79 s\n",
      "Wall time: 4.67 s\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "\n",
    "if show_on_test_batch:\n",
    "    batch = train_dataset.take(1)\n",
    "    print(batch)\n",
    "    for images, labels in batch:\n",
    "        print(model(images)[0], labels[0], predict_rank(model, images)[0])\n",
    "        print(model(images)[1], labels[1], predict_rank(model, images)[1])\n",
    "        print(model(images)[2], labels[2], predict_rank(model, images)[2])\n",
    "        break"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1f0a22c7-88a9-42a1-b1e3-edaab37b0cfb",
   "metadata": {},
   "outputs": [],
   "source": [
    "%%time\n",
    "\n",
    "test_metrics = model.evaluate(test_images, test_labels, return_dict=True)\n",
    "print(f\"Test metrics: {test_metrics}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "beff0b4e-b47f-47f5-b902-62191e5556d3",
   "metadata": {},
   "outputs": [],
   "source": [
    "%%time\n",
    "\n",
    "val_metrics = model.evaluate(val_images, val_labels, return_dict=True)\n",
    "print(f\"Val metrics: {test_metrics}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 302,
   "id": "041d408e-ed08-4400-9359-fcc40e3bc39a",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(62720, 512), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c5b1ca080>, 139759987427968), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(62720, 512), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c5b1ca080>, 139759987427968), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(512,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c5b1c8820>, 139759768162336), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(512,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c5b1c8820>, 139759768162336), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(512, 256), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c5b29c5b0>, 139759987429408), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(512, 256), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c5b29c5b0>, 139759987429408), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(256,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c5b29ebf0>, 139759987430528), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(256,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c5b29ebf0>, 139759987430528), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(256, 128), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c5b10fe50>, 139759987439728), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(256, 128), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c5b10fe50>, 139759987439728), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(128,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c5b10d120>, 139759987427168), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(128,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c5b10d120>, 139759987427168), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(128, 64), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c5b10e5f0>, 139759998155024), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(128, 64), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c5b10e5f0>, 139759998155024), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(64,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c5b10c460>, 139759998155104), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(64,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c5b10c460>, 139759998155104), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(64, 1), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c5b0eb910>, 139759998156544), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(64, 1), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c5b0eb910>, 139759998156544), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(1,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c5b0ea170>, 139759998155744), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(1,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c5b0ea170>, 139759998155744), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(62720, 512), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c5b1ca080>, 139759987427968), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(62720, 512), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c5b1ca080>, 139759987427968), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(512,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c5b1c8820>, 139759768162336), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(512,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c5b1c8820>, 139759768162336), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(512, 256), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c5b29c5b0>, 139759987429408), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(512, 256), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c5b29c5b0>, 139759987429408), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(256,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c5b29ebf0>, 139759987430528), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(256,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c5b29ebf0>, 139759987430528), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(256, 128), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c5b10fe50>, 139759987439728), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(256, 128), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c5b10fe50>, 139759987439728), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(128,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c5b10d120>, 139759987427168), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(128,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c5b10d120>, 139759987427168), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(128, 64), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c5b10e5f0>, 139759998155024), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(128, 64), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c5b10e5f0>, 139759998155024), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(64,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c5b10c460>, 139759998155104), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(64,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c5b10c460>, 139759998155104), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(64, 1), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c5b0eb910>, 139759998156544), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(64, 1), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c5b0eb910>, 139759998156544), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(1,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c5b0ea170>, 139759998155744), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(1,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c5b0ea170>, 139759998155744), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Assets written to: tesla_saved_models/EfficientNetB1RankingModel_20240614_131711_freezed_0.01/assets\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Assets written to: tesla_saved_models/EfficientNetB1RankingModel_20240614_131711_freezed_0.01/assets\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model saved in tesla_saved_models/EfficientNetB1RankingModel_20240614_131711_freezed_0.01 as EfficientNetB1RankingModel_20240614_131711_freezed_0.01\n",
      "CPU times: user 24.1 s, sys: 879 ms, total: 25 s\n",
      "Wall time: 24.9 s\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "\n",
    "if SAVE_MODEL:\n",
    "    save_model_with_timestamp_and_lr(model, lr, freezed=\"freezed\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "624af6fa-08b6-4a30-808b-0415f33d1db3",
   "metadata": {},
   "source": [
    "### Evaluating"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 397,
   "id": "079fdeb7-1558-44c0-8a33-eddebc8bbbdc",
   "metadata": {},
   "outputs": [],
   "source": [
    "loaded_model = tf.keras.models.load_model('saved_models/EfficientNetB1RankingModel_20240612_100333', \n",
    "                                          custom_objects={'ListMLELoss': tfr.keras.losses.ListMLELoss()})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a48a49d5-471e-426c-9954-13e914d36c6d",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "%%time\n",
    "\n",
    "metrics = compute_dataset_metrics(loaded_model, train_dataset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1a3fde0a-6261-4d01-a68a-39d5880de4d2",
   "metadata": {},
   "outputs": [],
   "source": [
    "%%time\n",
    "\n",
    "print(metrics)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ce94bd9e-62b2-4987-a50d-f26f173f8e6b",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "%%time\n",
    "\n",
    "metrics = compute_dataset_metrics(loaded_model, val_dataset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "367605d9-af6e-416a-88ab-bc5c54e75992",
   "metadata": {},
   "outputs": [],
   "source": [
    "%%time\n",
    "\n",
    "print(metrics)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "acf6bdcc-4bcb-49c9-98d5-fd65d71e9020",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "%%time\n",
    "\n",
    "metrics = compute_dataset_metrics(loaded_model, test_dataset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "033ee5a3-1b82-4e30-914c-a731fd0b5245",
   "metadata": {},
   "outputs": [],
   "source": [
    "%%time\n",
    "\n",
    "print(metrics)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "374e2017-bdde-4a6f-90ce-134baa136896",
   "metadata": {},
   "outputs": [],
   "source": [
    "%%time\n",
    "\n",
    "show_on_test_batch = True"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "06503066-f5d7-4dfb-9ea2-4d43538c190e",
   "metadata": {},
   "outputs": [],
   "source": [
    "%%time\n",
    "\n",
    "if show_on_test_batch:\n",
    "    batch = test_dataset.take(1)\n",
    "    print(batch)\n",
    "    for images, labels in batch:\n",
    "        print(model(images)[0], labels[0], predict_rank(model, images)[0])\n",
    "        print(model(images)[1], labels[1], predict_rank(model, images)[1])\n",
    "        print(model(images)[2], labels[2], predict_rank(model, images)[2])\n",
    "        break"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "168200ab-96b2-47ef-8681-4fa73e1f0ecf",
   "metadata": {},
   "outputs": [],
   "source": [
    "%%time\n",
    "\n",
    "plot_metrics(metrics)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4904b9bd-7b85-4302-bdbd-411602cfc5fc",
   "metadata": {},
   "source": [
    "## EfficientNet-B2"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ca61962f-ff63-4e52-9c5d-be62a0f2cb70",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true
   },
   "source": [
    "### Training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 303,
   "id": "c6f887b7-9ca2-40ca-adfe-0ae47626a3e0",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 2.13 s, sys: 28 ms, total: 2.16 s\n",
      "Wall time: 2.09 s\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "\n",
    "loss = tfr.keras.losses.ListMLELoss()\n",
    "model = EfficientNetB2RankingModel(loss)\n",
    "lr = 1e-2\n",
    "optimizer = tf.keras.optimizers.Adam(learning_rate=lr, clipnorm=1.0)\n",
    "model.compile(optimizer=optimizer)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 304,
   "id": "09f13d07-001d-41da-8e88-a2f1a2c064a9",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 3.21 ms, sys: 14 μs, total: 3.22 ms\n",
      "Wall time: 1.4 ms\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "\n",
    "batch = train_dataset.take(1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 305,
   "id": "3b20a21c-d740-4f44-bf54-3afef1747dd4",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/1 [==============================] - 9s 9s/step\n",
      "CPU times: user 9.29 s, sys: 450 ms, total: 9.74 s\n",
      "Wall time: 9.43 s\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2024-06-14 13:17:47.937495: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 12236039780566329570\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "array([[-0.24431625, -0.2356162 , -0.23422995, -0.23644073, -0.23674856],\n",
       "       [-0.24522087, -0.23834218, -0.23971614, -0.240123  , -0.24064183],\n",
       "       [-0.24038768, -0.2403    , -0.24074358, -0.23976347, -0.24041234],\n",
       "       [-0.24102136, -0.24061777, -0.24028185, -0.23930073, -0.24039212],\n",
       "       [-0.23702009, -0.2237636 , -0.23241492, -0.23658067, -0.23149948],\n",
       "       [-0.24238999, -0.23751706, -0.23515847, -0.23561887, -0.24530521],\n",
       "       [-0.2362248 , -0.23573819, -0.23504043, -0.23625186, -0.23687339],\n",
       "       [-0.22067134, -0.2302644 , -0.23390171, -0.23201916, -0.21556017],\n",
       "       [-0.24003288, -0.24195808, -0.24045852, -0.23665032, -0.23754218],\n",
       "       [-0.22095527, -0.23180842, -0.22410676, -0.22339772, -0.22072947],\n",
       "       [-0.23716733, -0.239032  , -0.23895997, -0.23934728, -0.23942623],\n",
       "       [-0.21810204, -0.23302975, -0.2264873 , -0.23816746, -0.23354365],\n",
       "       [-0.23871675, -0.23558232, -0.2355711 , -0.2348932 , -0.23906553],\n",
       "       [-0.23745833, -0.23789734, -0.23926777, -0.24014577, -0.2400221 ],\n",
       "       [-0.23381813, -0.23961757, -0.23785323, -0.2384641 , -0.23891276],\n",
       "       [-0.23923674, -0.23929682, -0.24101208, -0.23908809, -0.24099162],\n",
       "       [-0.23436916, -0.2357154 , -0.2360313 , -0.23140198, -0.23509622],\n",
       "       [-0.23879918, -0.24113414, -0.23973511, -0.2409378 , -0.2412869 ],\n",
       "       [-0.23782699, -0.21284686, -0.23340149, -0.23507647, -0.23472454],\n",
       "       [-0.23875713, -0.23894769, -0.24034294, -0.2381708 , -0.23863333],\n",
       "       [-0.23947224, -0.23549859, -0.22356306, -0.23576   , -0.22947109],\n",
       "       [-0.23492546, -0.23869601, -0.23875484, -0.236043  , -0.23901619],\n",
       "       [-0.23871747, -0.23367125, -0.23365918, -0.23577373, -0.23867197],\n",
       "       [-0.23851442, -0.24108121, -0.24151194, -0.24073093, -0.24146852],\n",
       "       [-0.2397461 , -0.239889  , -0.2389364 , -0.23723277, -0.2420284 ],\n",
       "       [-0.23730025, -0.23967151, -0.24123469, -0.24077557, -0.23991053],\n",
       "       [-0.23976727, -0.23898107, -0.23933205, -0.23926167, -0.2400063 ],\n",
       "       [-0.23693761, -0.24041703, -0.24020004, -0.24007943, -0.24048086],\n",
       "       [-0.22413816, -0.22213912, -0.22483306, -0.23388937, -0.2254158 ],\n",
       "       [-0.23882568, -0.23819023, -0.23728514, -0.23676053, -0.23870684]],\n",
       "      dtype=float32)"
      ]
     },
     "execution_count": 305,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "%%time\n",
    "\n",
    "model.predict(batch)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 306,
   "id": "3d63afa9-8be9-4f83-b1df-1913a7f0684b",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model: \"efficient_net_b2_ranking_model_2\"\n",
      "_________________________________________________________________\n",
      " Layer (type)                Output Shape              Param #   \n",
      "=================================================================\n",
      " efficientnetb2 (Functional  (None, 7, 7, 1408)        7768569   \n",
      " )                                                               \n",
      "                                                                 \n",
      " flatten_15 (Flatten)        multiple                  0         \n",
      "                                                                 \n",
      " sequential_30 (Sequential)  (None, 64)                35496896  \n",
      "                                                                 \n",
      " sequential_31 (Sequential)  (None, 1)                 65        \n",
      "                                                                 \n",
      " ranking_15 (Ranking)        multiple                  0 (unused)\n",
      "                                                                 \n",
      "=================================================================\n",
      "Total params: 43265530 (165.04 MB)\n",
      "Trainable params: 35496961 (135.41 MB)\n",
      "Non-trainable params: 7768569 (29.63 MB)\n",
      "_________________________________________________________________\n",
      "CPU times: user 30.3 ms, sys: 3.95 ms, total: 34.2 ms\n",
      "Wall time: 31.6 ms\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "\n",
    "model.summary()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 307,
   "id": "3eabaaba-2ea9-4d4b-af4f-8e605de628a0",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1/10\n",
      "    125/Unknown - 436s 3s/step - ndcg_metric: 0.7477 - mrr_metric: 0.9196 - opa_metric: 0.5773 - loss: 5.2971 - regularization_loss: 0.0000e+00 - total_loss: 5.2971"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2024-06-14 13:25:07.894500: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 17135038215022731337\n",
      "2024-06-14 13:25:07.894571: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 9831072595876937903\n",
      "2024-06-14 13:25:07.894594: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 13200909488153263943\n",
      "2024-06-14 13:25:07.894610: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 493421264522331993\n",
      "2024-06-14 13:25:07.894626: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 1270010923177469863\n",
      "2024-06-14 13:25:07.894644: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 11037030121684978685\n",
      "2024-06-14 13:25:07.894663: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 17157022900605095531\n",
      "2024-06-14 13:25:07.894689: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 13127562949493346195\n",
      "2024-06-14 13:25:07.894715: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 13695389863566061629\n",
      "2024-06-14 13:25:07.894742: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 178534479207338621\n",
      "2024-06-14 13:25:07.894767: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 13382154064904926309\n",
      "2024-06-14 13:25:07.894793: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 7093833028984133667\n",
      "2024-06-14 13:25:07.894821: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 9696221818698095079\n",
      "2024-06-14 13:25:07.894845: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 16048656414016575581\n",
      "2024-06-14 13:25:07.894872: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 12119068419854916647\n",
      "2024-06-14 13:25:07.894899: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 147161527434312755\n",
      "2024-06-14 13:25:07.894926: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 1200190052937129893\n",
      "2024-06-14 13:25:07.894966: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 7220601810873696478\n",
      "2024-06-14 13:25:07.894992: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 1513739858723832088\n",
      "2024-06-14 13:25:07.895019: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 16346056015531131702\n",
      "2024-06-14 13:25:07.895050: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 8483212900484458216\n",
      "2024-06-14 13:25:07.895081: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 8578633034348292240\n",
      "2024-06-14 13:25:07.895110: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 7805449860732438904\n",
      "2024-06-14 13:25:07.895138: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 3090062051100437360\n",
      "2024-06-14 13:25:07.895168: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 15247288628218433230\n",
      "2024-06-14 13:25:07.895196: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 7764637174391148542\n",
      "2024-06-14 13:25:07.895223: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 12955912133959897468\n",
      "2024-06-14 13:25:07.895252: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 7293075784703751104\n",
      "2024-06-14 13:25:07.895278: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 12159700787590215346\n",
      "2024-06-14 13:25:07.895306: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 1665721468661144582\n",
      "2024-06-14 13:25:07.895333: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 9836333925907168492\n",
      "2024-06-14 13:25:07.895484: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 7921111971223175096\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "125/125 [==============================] - 440s 3s/step - ndcg_metric: 0.7477 - mrr_metric: 0.9196 - opa_metric: 0.5773 - loss: 5.2912 - regularization_loss: 0.0000e+00 - total_loss: 5.2912\n",
      "Epoch 2/10\n",
      "125/125 [==============================] - 430s 3s/step - ndcg_metric: 0.7798 - mrr_metric: 0.9437 - opa_metric: 0.6374 - loss: 4.6099 - regularization_loss: 0.0000e+00 - total_loss: 4.6099\n",
      "Epoch 3/10\n",
      "125/125 [==============================] - 432s 3s/step - ndcg_metric: 0.7946 - mrr_metric: 0.9515 - opa_metric: 0.6582 - loss: 4.4905 - regularization_loss: 0.0000e+00 - total_loss: 4.4905\n",
      "Epoch 4/10\n",
      "125/125 [==============================] - 428s 3s/step - ndcg_metric: 0.7996 - mrr_metric: 0.9509 - opa_metric: 0.6634 - loss: 4.4857 - regularization_loss: 0.0000e+00 - total_loss: 4.4857\n",
      "Epoch 5/10\n",
      "125/125 [==============================] - 428s 3s/step - ndcg_metric: 0.8024 - mrr_metric: 0.9517 - opa_metric: 0.6660 - loss: 4.4371 - regularization_loss: 0.0000e+00 - total_loss: 4.4371\n",
      "Epoch 6/10\n",
      "125/125 [==============================] - 430s 3s/step - ndcg_metric: 0.8097 - mrr_metric: 0.9576 - opa_metric: 0.6801 - loss: 4.3919 - regularization_loss: 0.0000e+00 - total_loss: 4.3919\n",
      "Epoch 7/10\n",
      "125/125 [==============================] - 433s 3s/step - ndcg_metric: 0.8103 - mrr_metric: 0.9599 - opa_metric: 0.6813 - loss: 4.3672 - regularization_loss: 0.0000e+00 - total_loss: 4.3672\n",
      "Epoch 8/10\n",
      "125/125 [==============================] - 430s 3s/step - ndcg_metric: 0.8121 - mrr_metric: 0.9599 - opa_metric: 0.6836 - loss: 4.3364 - regularization_loss: 0.0000e+00 - total_loss: 4.3364\n",
      "Epoch 9/10\n",
      "125/125 [==============================] - 428s 3s/step - ndcg_metric: 0.8101 - mrr_metric: 0.9573 - opa_metric: 0.6817 - loss: 4.3531 - regularization_loss: 0.0000e+00 - total_loss: 4.3531\n",
      "Epoch 10/10\n",
      "125/125 [==============================] - 429s 3s/step - ndcg_metric: 0.8130 - mrr_metric: 0.9607 - opa_metric: 0.6854 - loss: 4.3518 - regularization_loss: 0.0000e+00 - total_loss: 4.3518\n",
      "CPU times: user 1h 11min 8s, sys: 2min 29s, total: 1h 13min 37s\n",
      "Wall time: 1h 11min 49s\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "\n",
    "history = model.fit(train_dataset, epochs=EPOCHS)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 308,
   "id": "30db616b-ef12-4355-bdfd-2c4cc965d8b1",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 6 μs, sys: 0 ns, total: 6 μs\n",
      "Wall time: 13.4 μs\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "\n",
    "show_on_test_batch = True"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 309,
   "id": "2bfd8c8a-0b1c-4fc3-9169-b92a208f2b7e",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<_TakeDataset element_spec=(TensorSpec(shape=(None, 5, 224, 224, 3), dtype=tf.float32, name=None), TensorSpec(shape=(None, 5), dtype=tf.int32, name=None))>\n",
      "tf.Tensor([670.69476 671.15436 671.4244  670.23047 670.9713 ], shape=(5,), dtype=float32) tf.Tensor([2 3 4 0 1], shape=(5,), dtype=int32) tf.Tensor([1 3 4 0 2], shape=(5,), dtype=int32)\n",
      "tf.Tensor([671.1575  671.1157  671.0053  669.87726 671.1564 ], shape=(5,), dtype=float32) tf.Tensor([4 3 1 0 2], shape=(5,), dtype=int32) tf.Tensor([4 2 1 0 3], shape=(5,), dtype=int32)\n",
      "tf.Tensor([670.7773  671.0239  670.0962  670.81885 671.05237], shape=(5,), dtype=float32) tf.Tensor([2 3 1 0 4], shape=(5,), dtype=int32) tf.Tensor([1 3 0 2 4], shape=(5,), dtype=int32)\n",
      "CPU times: user 5.27 s, sys: 375 ms, total: 5.64 s\n",
      "Wall time: 4.58 s\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "\n",
    "if show_on_test_batch:\n",
    "    batch = train_dataset.take(1)\n",
    "    print(batch)\n",
    "    for images, labels in batch:\n",
    "        print(model(images)[0], labels[0], predict_rank(model, images)[0])\n",
    "        print(model(images)[1], labels[1], predict_rank(model, images)[1])\n",
    "        print(model(images)[2], labels[2], predict_rank(model, images)[2])\n",
    "        break"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d1ff515c-5659-4903-9751-f202e42bd933",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "%%time\n",
    "\n",
    "test_metrics = model.evaluate(test_images, test_labels, return_dict=True)\n",
    "print(f\"Test metrics: {test_metrics}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "56eaad58-948e-4094-b2f9-7cfa45a504e6",
   "metadata": {},
   "outputs": [],
   "source": [
    "%%time\n",
    "\n",
    "val_metrics = model.evaluate(val_images, val_labels, return_dict=True)\n",
    "print(f\"Val metrics: {test_metrics}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 310,
   "id": "f1c6a4dd-9572-4028-941b-ff73c5dab185",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(68992, 512), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c60ba1540>, 139759996331840), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(68992, 512), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c60ba1540>, 139759996331840), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(512,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c889af880>, 139759996327760), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(512,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c889af880>, 139759996327760), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(512, 256), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c889acca0>, 139758922344096), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(512, 256), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c889acca0>, 139758922344096), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(256,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c889aca60>, 139758922346416), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(256,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c889aca60>, 139758922346416), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(256, 128), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c60b7d390>, 139758922338416), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(256, 128), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c60b7d390>, 139758922338416), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(128,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c60b7f280>, 139758922341376), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(128,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c60b7f280>, 139758922341376), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(128, 64), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1ca307e2f0>, 139759989131024), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(128, 64), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1ca307e2f0>, 139759989131024), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(64,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c8b2f1c90>, 139759989131104), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(64,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c8b2f1c90>, 139759989131104), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(64, 1), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c8b2f28f0>, 139759989139664), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(64, 1), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c8b2f28f0>, 139759989139664), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(1,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c8b2f36a0>, 139759989143584), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(1,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c8b2f36a0>, 139759989143584), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(68992, 512), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c60ba1540>, 139759996331840), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(68992, 512), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c60ba1540>, 139759996331840), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(512,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c889af880>, 139759996327760), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(512,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c889af880>, 139759996327760), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(512, 256), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c889acca0>, 139758922344096), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(512, 256), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c889acca0>, 139758922344096), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(256,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c889aca60>, 139758922346416), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(256,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c889aca60>, 139758922346416), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(256, 128), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c60b7d390>, 139758922338416), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(256, 128), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c60b7d390>, 139758922338416), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(128,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c60b7f280>, 139758922341376), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(128,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c60b7f280>, 139758922341376), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(128, 64), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1ca307e2f0>, 139759989131024), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(128, 64), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1ca307e2f0>, 139759989131024), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(64,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c8b2f1c90>, 139759989131104), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(64,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c8b2f1c90>, 139759989131104), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(64, 1), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c8b2f28f0>, 139759989139664), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(64, 1), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c8b2f28f0>, 139759989139664), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(1,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c8b2f36a0>, 139759989143584), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(1,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c8b2f36a0>, 139759989143584), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Assets written to: tesla_saved_models/EfficientNetB2RankingModel_20240614_142941_freezed_0.01/assets\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Assets written to: tesla_saved_models/EfficientNetB2RankingModel_20240614_142941_freezed_0.01/assets\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model saved in tesla_saved_models/EfficientNetB2RankingModel_20240614_142941_freezed_0.01 as EfficientNetB2RankingModel_20240614_142941_freezed_0.01\n",
      "CPU times: user 24.6 s, sys: 941 ms, total: 25.6 s\n",
      "Wall time: 25.4 s\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "\n",
    "if SAVE_MODEL:\n",
    "    save_model_with_timestamp_and_lr(model, lr, freezed=\"freezed\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c95e7809-309c-4823-9076-6ab22e614551",
   "metadata": {},
   "source": [
    "### Evaluating"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "63fed100-5304-4f83-94e0-c3f3ad0a8cbf",
   "metadata": {},
   "outputs": [],
   "source": [
    "loaded_model = tf.keras.models.load_model('saved_models/EfficientNetB2RankingModel_20240612_112012', \n",
    "                                          custom_objects={'ListMLELoss': tfr.keras.losses.ListMLELoss()})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bd02a177-5ca6-4dad-9fb5-c93ae914e28f",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "%%time\n",
    "\n",
    "metrics = compute_dataset_metrics(loaded_model, train_dataset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9c3aafd2-dcc4-4cf1-a0b0-02f04bbf33b8",
   "metadata": {},
   "outputs": [],
   "source": [
    "%%time\n",
    "\n",
    "print(metrics)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "070afdd5-dcee-4237-b786-134a1f6d6e15",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "%%time\n",
    "\n",
    "metrics = compute_dataset_metrics(loaded_model, val_dataset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "51f3330b-ccb7-46e1-a969-afc8f379d9f7",
   "metadata": {},
   "outputs": [],
   "source": [
    "%%time\n",
    "\n",
    "print(metrics)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ef18fa04-e673-4f30-9c6f-cf898df76379",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "%%time\n",
    "\n",
    "metrics = compute_dataset_metrics(loaded_model, test_dataset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "846e36e2-ed64-4a69-a5a8-fda00eaaee14",
   "metadata": {},
   "outputs": [],
   "source": [
    "%%time\n",
    "\n",
    "print(metrics)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "72208d5e-4e4e-4e99-8790-ec456b706e9f",
   "metadata": {},
   "outputs": [],
   "source": [
    "%%time\n",
    "\n",
    "show_on_test_batch = True"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "13efde8b-3a3d-4799-811f-1f01b01e9565",
   "metadata": {},
   "outputs": [],
   "source": [
    "%%time\n",
    "\n",
    "if show_on_test_batch:\n",
    "    batch = test_dataset.take(1)\n",
    "    print(batch)\n",
    "    for images, labels in batch:\n",
    "        print(model(images)[0], labels[0], predict_rank(model, images)[0])\n",
    "        print(model(images)[1], labels[1], predict_rank(model, images)[1])\n",
    "        print(model(images)[2], labels[2], predict_rank(model, images)[2])\n",
    "        break"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ca07118f-7996-4f6b-b8a4-46ac603bde99",
   "metadata": {},
   "outputs": [],
   "source": [
    "%%time\n",
    "\n",
    "plot_metrics(metrics)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4cca242e-d6fb-40cc-9b91-e5433e63a305",
   "metadata": {},
   "source": [
    "## MobileNetV2"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9efe28d8-3805-439c-bd21-fccc0b2cedd5",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true
   },
   "source": [
    "### Training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e2565e69-3ffd-47c4-8b1e-04269a25423b",
   "metadata": {},
   "outputs": [],
   "source": [
    "%%time\n",
    "\n",
    "loss = tfr.keras.losses.ListMLELoss()\n",
    "model = MobileNetV2RankingModel(loss)\n",
    "lr = 1e-5\n",
    "optimizer = tf.keras.optimizers.Adam(learning_rate=lr, clipnorm=1.0)\n",
    "model.compile(optimizer=optimizer)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "118dcaf6-bafa-4869-a400-f0e9f0d4e32f",
   "metadata": {},
   "outputs": [],
   "source": [
    "%%time\n",
    "\n",
    "batch = train_dataset.take(1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a246d008-3265-448c-a492-bf4920382748",
   "metadata": {},
   "outputs": [],
   "source": [
    "%%time\n",
    "\n",
    "model.predict(batch)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7cba90a9-5aec-4eaa-bf06-947f82aaa478",
   "metadata": {},
   "outputs": [],
   "source": [
    "%%time\n",
    "\n",
    "model.summary()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "17e8d291-cf04-4a00-8c14-6bee7894d5e4",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "%%time\n",
    "\n",
    "history = model.fit(train_dataset, epochs=EPOCHS)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "029a2812-6e09-412b-9af6-e6b95e9aaf1e",
   "metadata": {},
   "outputs": [],
   "source": [
    "%%time\n",
    "\n",
    "show_on_test_batch = True"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "26787416-ac98-4a70-b209-d33ed14a86ef",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "%%time\n",
    "\n",
    "if show_on_test_batch:\n",
    "    batch = train_dataset.take(1)\n",
    "    print(batch)\n",
    "    for images, labels in batch:\n",
    "        print(model(images)[0], labels[0], predict_rank(model, images)[0])\n",
    "        print(model(images)[1], labels[1], predict_rank(model, images)[1])\n",
    "        print(model(images)[2], labels[2], predict_rank(model, images)[2])\n",
    "        break"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4ee43efd-361e-4822-a290-84d29d8983b1",
   "metadata": {},
   "outputs": [],
   "source": [
    "%%time\n",
    "\n",
    "test_metrics = model.evaluate(test_images, test_labels, return_dict=True)\n",
    "print(f\"Test metrics: {test_metrics}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f030c1d9-3166-4fb1-b269-67958b736481",
   "metadata": {},
   "outputs": [],
   "source": [
    "%%time\n",
    "\n",
    "val_metrics = model.evaluate(val_images, val_labels, return_dict=True)\n",
    "print(f\"Val metrics: {test_metrics}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dd7722ad-9d82-4a58-9ec1-280cacef4339",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "%%time\n",
    "\n",
    "if SAVE_MODEL:\n",
    "    save_model_with_timestamp_and_lr(model, lr, freezed=\"freezed\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5e404d33-b84a-4bb3-875f-176fabce66b9",
   "metadata": {},
   "source": [
    "### Evaluating"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c81b9ff9-11b7-4ddf-bc42-52ce2e436571",
   "metadata": {},
   "outputs": [],
   "source": [
    "loaded_model = tf.keras.models.load_model('saved_models/MobileNetV2RankingModel_20240612_011038', \n",
    "                                          custom_objects={'ListMLELoss': tfr.keras.losses.ListMLELoss()})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0cb210cc-8096-4a72-81f2-297697386f84",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "%%time\n",
    "\n",
    "metrics = compute_dataset_metrics(loaded_model, train_dataset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a2f927dd-81a3-4b82-99cf-2edd64d9f1a2",
   "metadata": {},
   "outputs": [],
   "source": [
    "%%time\n",
    "\n",
    "print(metrics)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "566949ed-22b0-42d1-a755-730ce0ee3fe0",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "%%time\n",
    "\n",
    "metrics = compute_dataset_metrics(loaded_model, val_dataset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "28800293-79cf-424b-b5f2-1e33282fbd79",
   "metadata": {},
   "outputs": [],
   "source": [
    "%%time\n",
    "\n",
    "print(metrics)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e40d3108-d107-4410-941e-7473009deaba",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "%%time\n",
    "\n",
    "metrics = compute_dataset_metrics(loaded_model, test_dataset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d48dce9d-791d-4517-a7fd-89738a0af8ab",
   "metadata": {},
   "outputs": [],
   "source": [
    "%%time\n",
    "\n",
    "print(metrics)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6b30778d-97a2-4b8b-ba8c-923623ebb7fb",
   "metadata": {},
   "outputs": [],
   "source": [
    "%%time\n",
    "\n",
    "show_on_test_batch = True"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "889ccb04-0b25-44f2-a00b-0cd9b76edc7e",
   "metadata": {},
   "outputs": [],
   "source": [
    "%%time\n",
    "\n",
    "if show_on_test_batch:\n",
    "    batch = test_dataset.take(1)\n",
    "    print(batch)\n",
    "    for images, labels in batch:\n",
    "        print(model(images)[0], labels[0], predict_rank(model, images)[0])\n",
    "        print(model(images)[1], labels[1], predict_rank(model, images)[1])\n",
    "        print(model(images)[2], labels[2], predict_rank(model, images)[2])\n",
    "        break"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f56bddf9-e786-4b9e-be75-cc5191c9e179",
   "metadata": {},
   "outputs": [],
   "source": [
    "%%time\n",
    "\n",
    "plot_metrics(metrics)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "69b44d26-7208-478c-a99a-e923327148e8",
   "metadata": {},
   "source": [
    "## VGG16"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "19825fcf-7c12-4ca4-b562-c0ade23bff3f",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true
   },
   "source": [
    "### Training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 239,
   "id": "3c5374cd-2510-45e3-8bf2-165c93d0014a",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 275 ms, sys: 27.9 ms, total: 303 ms\n",
      "Wall time: 293 ms\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "\n",
    "loss = tfr.keras.losses.ListMLELoss()\n",
    "model = VGG16RankingModel(loss)\n",
    "lr = 1e-4\n",
    "optimizer = tf.keras.optimizers.Adam(learning_rate=lr, clipnorm=1.0)\n",
    "model.compile(optimizer=optimizer)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 240,
   "id": "ba073a30-c217-4a66-8130-5a47eb99564a",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 3.13 ms, sys: 0 ns, total: 3.13 ms\n",
      "Wall time: 1.65 ms\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "\n",
    "batch = train_dataset.take(1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 241,
   "id": "93b0532b-6ad9-4d8d-96c5-bcb63b291cfa",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:5 out of the last 129 calls to <function Model.make_predict_function.<locals>.predict_function at 0x7f1c979dd750> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has reduce_retracing=True option that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/guide/function#controlling_retracing and https://www.tensorflow.org/api_docs/python/tf/function for  more details.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:5 out of the last 129 calls to <function Model.make_predict_function.<locals>.predict_function at 0x7f1c979dd750> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has reduce_retracing=True option that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/guide/function#controlling_retracing and https://www.tensorflow.org/api_docs/python/tf/function for  more details.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/1 [==============================] - 7s 7s/step\n",
      "CPU times: user 7.22 s, sys: 564 ms, total: 7.79 s\n",
      "Wall time: 7.5 s\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2024-06-13 23:43:55.211277: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 15390714029690532465\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "array([[ 0.23688674,  0.269287  , -0.03482419,  0.3778296 ,  0.24500519],\n",
       "       [ 0.29883084,  0.25655425,  0.21707058,  0.15067598,  0.13689235],\n",
       "       [ 0.15531474,  0.11916703,  0.42833632,  0.19256277,  0.12866649],\n",
       "       [ 0.2600698 , -0.0109174 ,  0.13296297,  0.27727035,  0.10815731],\n",
       "       [ 0.30631405,  0.44057435,  0.17988461,  0.4656174 ,  0.24805163],\n",
       "       [ 0.13529018,  0.30076703,  0.18921301,  0.24148856,  0.14383778],\n",
       "       [ 0.00733489,  0.09934801,  0.04122813,  0.1105479 ,  0.2009179 ],\n",
       "       [ 0.24115282,  0.26656887,  0.13941897,  0.230727  ,  0.00370467],\n",
       "       [ 0.16144213,  0.5278568 ,  0.11020926,  0.10412404,  0.29114157],\n",
       "       [ 0.30201635,  0.25502402,  0.5772911 ,  0.09063175,  0.58642304],\n",
       "       [ 0.15189691,  0.14661287,  0.02176464,  0.32809204,  0.24007088],\n",
       "       [ 0.4502493 ,  0.65245545,  0.09593384,  0.05313045,  0.40355688],\n",
       "       [ 0.12720814,  0.23734772,  0.02563155,  0.20102972,  0.24783301],\n",
       "       [ 0.2086528 ,  0.15854307,  0.23629479, -0.0363265 ,  0.1778763 ],\n",
       "       [ 0.34591144,  0.25370684,  0.23561583,  0.2084949 ,  0.11075954],\n",
       "       [ 0.17764887,  0.21684632, -0.04103452,  0.15650027,  0.05164635],\n",
       "       [ 0.32560566,  0.30073312,  0.40607792,  0.2086425 ,  0.30389625],\n",
       "       [ 0.28958958,  0.19481763,  0.15696597,  0.17831613, -0.04088175],\n",
       "       [ 0.09441225,  0.36265847,  0.23993894,  0.06896776,  0.26331964],\n",
       "       [ 0.10947068,  0.3226364 ,  0.14862294,  0.18417333,  0.06680377],\n",
       "       [-0.0682701 , -0.04377615,  0.2591668 ,  0.09497362,  0.05907619],\n",
       "       [ 0.1872338 ,  0.5348089 ,  0.26788875,  0.00416011,  0.28526396],\n",
       "       [ 0.02371711,  0.19841653,  0.31326625,  0.03719243,  0.19770907],\n",
       "       [ 0.19127998,  0.09205037,  0.00171719,  0.24925314,  0.21914893],\n",
       "       [ 0.30496103,  0.07090488,  0.51301223,  0.26709723,  0.22266479],\n",
       "       [ 0.01799698,  0.18815312,  0.23896736,  0.03914312,  0.28013417],\n",
       "       [ 0.30820847,  0.3121579 ,  0.35672623,  0.29765502,  0.23578714],\n",
       "       [ 0.29378617,  0.12519105,  0.1515439 ,  0.27455541,  0.34391165],\n",
       "       [-0.08628979,  0.35823348,  0.21511377,  0.4910087 ,  0.28604394],\n",
       "       [ 0.3685879 ,  0.19155945,  0.3061325 ,  0.25185347,  0.23089947]],\n",
       "      dtype=float32)"
      ]
     },
     "execution_count": 241,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "%%time\n",
    "\n",
    "model.predict(batch)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 242,
   "id": "5daea2d1-958d-4c99-82b7-14ad84f21403",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model: \"vgg16_ranking_model\"\n",
      "_________________________________________________________________\n",
      " Layer (type)                Output Shape              Param #   \n",
      "=================================================================\n",
      " vgg16 (Functional)          (None, 7, 7, 512)         14714688  \n",
      "                                                                 \n",
      " flatten_7 (Flatten)         multiple                  0         \n",
      "                                                                 \n",
      " sequential_14 (Sequential)  (None, 64)                13018048  \n",
      "                                                                 \n",
      " sequential_15 (Sequential)  (None, 1)                 65        \n",
      "                                                                 \n",
      " ranking_7 (Ranking)         multiple                  0 (unused)\n",
      "                                                                 \n",
      "=================================================================\n",
      "Total params: 27732801 (105.79 MB)\n",
      "Trainable params: 27732801 (105.79 MB)\n",
      "Non-trainable params: 0 (0.00 Byte)\n",
      "_________________________________________________________________\n",
      "CPU times: user 16.9 ms, sys: 7.73 ms, total: 24.6 ms\n",
      "Wall time: 21.4 ms\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "\n",
    "model.summary()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 243,
   "id": "9dc2ab8b-1435-4e54-8d8c-94c3609c634b",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1/10\n",
      "125/125 [==============================] - 520s 4s/step - ndcg_metric: 0.8401 - mrr_metric: 0.9711 - opa_metric: 0.7221 - loss: 4.0279 - regularization_loss: 0.0000e+00 - total_loss: 4.0279\n",
      "Epoch 2/10\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2024-06-13 23:52:34.892878: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 7878071529173532157\n",
      "2024-06-13 23:52:34.892918: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 7846877891590815311\n",
      "2024-06-13 23:52:34.892933: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 12638169692605933059\n",
      "2024-06-13 23:52:34.892946: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 12640315801025111887\n",
      "2024-06-13 23:52:34.892958: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 3003783417270316761\n",
      "2024-06-13 23:52:34.892968: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 10881550688257158947\n",
      "2024-06-13 23:52:34.892979: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 15055576340762988925\n",
      "2024-06-13 23:52:34.892990: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 15390714029690532465\n",
      "2024-06-13 23:52:34.893001: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 8644040634962522911\n",
      "2024-06-13 23:52:34.893011: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 13088629136610595653\n",
      "2024-06-13 23:52:34.893021: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 10716319904263870411\n",
      "2024-06-13 23:52:34.893032: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 3999045217104320699\n",
      "2024-06-13 23:52:34.893043: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 11904728462121908945\n",
      "2024-06-13 23:52:34.893050: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 7335259745625079133\n",
      "2024-06-13 23:52:34.893057: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 5400568496618569103\n",
      "2024-06-13 23:52:34.893065: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 5864308722232385411\n",
      "2024-06-13 23:52:34.893078: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 15221140525949403710\n",
      "2024-06-13 23:52:34.893085: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 7289713512421860250\n",
      "2024-06-13 23:52:34.893092: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 4740728694794175656\n",
      "2024-06-13 23:52:34.893099: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 8420771768681022062\n",
      "2024-06-13 23:52:34.893106: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 7728875926801548036\n",
      "2024-06-13 23:52:34.893112: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 12664450691505465382\n",
      "2024-06-13 23:52:34.893121: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 7646603089293452164\n",
      "2024-06-13 23:52:34.893130: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 8117779196292883356\n",
      "2024-06-13 23:52:34.893136: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 9462556850835320504\n",
      "2024-06-13 23:52:34.893143: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 14954749817996095652\n",
      "2024-06-13 23:52:34.893150: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 18152611565809742690\n",
      "2024-06-13 23:52:34.893156: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 17586904816442374370\n",
      "2024-06-13 23:52:34.893163: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 8871757161215597482\n",
      "2024-06-13 23:52:34.893170: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 2117313457495514004\n",
      "2024-06-13 23:52:34.893177: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 9158975927443619112\n",
      "2024-06-13 23:52:34.893230: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 14495122275425991844\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "125/125 [==============================] - 496s 4s/step - ndcg_metric: 0.8434 - mrr_metric: 0.9676 - opa_metric: 0.6465 - loss: nan - regularization_loss: 0.0000e+00 - total_loss: nan\n",
      "Epoch 3/10\n",
      "125/125 [==============================] - 491s 4s/step - ndcg_metric: 0.7178 - mrr_metric: 0.8983 - opa_metric: 0.0000e+00 - loss: nan - regularization_loss: 0.0000e+00 - total_loss: nan\n",
      "Epoch 4/10\n",
      "125/125 [==============================] - 492s 4s/step - ndcg_metric: 0.7166 - mrr_metric: 0.8973 - opa_metric: 0.0000e+00 - loss: nan - regularization_loss: 0.0000e+00 - total_loss: nan\n",
      "Epoch 5/10\n",
      "125/125 [==============================] - 494s 4s/step - ndcg_metric: 0.7200 - mrr_metric: 0.8961 - opa_metric: 0.0000e+00 - loss: nan - regularization_loss: 0.0000e+00 - total_loss: nan\n",
      "Epoch 6/10\n",
      "125/125 [==============================] - 493s 4s/step - ndcg_metric: 0.7128 - mrr_metric: 0.9005 - opa_metric: 0.0000e+00 - loss: nan - regularization_loss: 0.0000e+00 - total_loss: nan\n",
      "Epoch 7/10\n",
      "125/125 [==============================] - 493s 4s/step - ndcg_metric: 0.7181 - mrr_metric: 0.8897 - opa_metric: 0.0000e+00 - loss: nan - regularization_loss: 0.0000e+00 - total_loss: nan\n",
      "Epoch 8/10\n",
      "125/125 [==============================] - 493s 4s/step - ndcg_metric: 0.7160 - mrr_metric: 0.9020 - opa_metric: 0.0000e+00 - loss: nan - regularization_loss: 0.0000e+00 - total_loss: nan\n",
      "Epoch 9/10\n",
      "125/125 [==============================] - 493s 4s/step - ndcg_metric: 0.7195 - mrr_metric: 0.9000 - opa_metric: 0.0000e+00 - loss: nan - regularization_loss: 0.0000e+00 - total_loss: nan\n",
      "Epoch 10/10\n",
      "125/125 [==============================] - 497s 4s/step - ndcg_metric: 0.7178 - mrr_metric: 0.9007 - opa_metric: 0.0000e+00 - loss: nan - regularization_loss: 0.0000e+00 - total_loss: nan\n",
      "CPU times: user 1h 23min 4s, sys: 2min 51s, total: 1h 25min 56s\n",
      "Wall time: 1h 22min 43s\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "\n",
    "history = model.fit(train_dataset, epochs=EPOCHS)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 244,
   "id": "124b348c-7859-4507-abc3-1d5d22a94bb6",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 7 μs, sys: 0 ns, total: 7 μs\n",
      "Wall time: 13.6 μs\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "\n",
    "show_on_test_batch = True"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 245,
   "id": "15a03d26-87f6-4361-bf5c-a3bc14d29fef",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<_TakeDataset element_spec=(TensorSpec(shape=(None, 5, 224, 224, 3), dtype=tf.float32, name=None), TensorSpec(shape=(None, 5), dtype=tf.int32, name=None))>\n",
      "tf.Tensor([nan nan nan nan nan], shape=(5,), dtype=float32) tf.Tensor([2 1 0 4 3], shape=(5,), dtype=int32) tf.Tensor([0 1 2 3 4], shape=(5,), dtype=int32)\n",
      "tf.Tensor([nan nan nan nan nan], shape=(5,), dtype=float32) tf.Tensor([2 1 0 4 3], shape=(5,), dtype=int32) tf.Tensor([0 1 2 3 4], shape=(5,), dtype=int32)\n",
      "tf.Tensor([nan nan nan nan nan], shape=(5,), dtype=float32) tf.Tensor([1 2 0 4 3], shape=(5,), dtype=int32) tf.Tensor([0 1 2 3 4], shape=(5,), dtype=int32)\n",
      "CPU times: user 30 s, sys: 486 ms, total: 30.5 s\n",
      "Wall time: 20.5 s\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "\n",
    "if show_on_test_batch:\n",
    "    batch = train_dataset.take(1)\n",
    "    print(batch)\n",
    "    for images, labels in batch:\n",
    "        print(model(images)[0], labels[0], predict_rank(model, images)[0])\n",
    "        print(model(images)[1], labels[1], predict_rank(model, images)[1])\n",
    "        print(model(images)[2], labels[2], predict_rank(model, images)[2])\n",
    "        break"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "id": "9a73e411-81c2-4c10-b54d-a60f534dc7ff",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2024-05-19 07:27:54.878855: W external/local_tsl/tsl/framework/bfc_allocator.cc:296] Allocator (GPU_0_bfc) ran out of memory trying to allocate 3.01GiB with freed_by_count=0. The caller indicates that this is not a failure, but this may mean that there could be performance gains if more memory were available.\n",
      "2024-05-19 07:27:54.878888: W tensorflow/core/kernels/gpu_utils.cc:54] Failed to allocate memory for convolution redzone checking; skipping this check. This is benign and only means that we won't check cudnn for out-of-bounds reads and writes. This message will only be printed once.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/1 [==============================] - 23s 23s/step - ndcg_metric: 0.9074 - mrr_metric: 1.0000 - opa_metric: 0.7980 - loss: 7.2471 - regularization_loss: 0.0000e+00 - total_loss: 7.2471\n",
      "Test metrics: {'ndcg_metric': 0.9073922038078308, 'mrr_metric': 1.0, 'opa_metric': 0.7979999780654907, 'loss': 7.247115612030029, 'regularization_loss': 0, 'total_loss': 7.247115612030029}\n"
     ]
    }
   ],
   "source": [
    "test_metrics = model.evaluate(test_images, test_labels, return_dict=True)\n",
    "print(f\"Test metrics: {test_metrics}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "id": "f22d7859-959b-4f6a-b3a2-a09906d152ab",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/1 [==============================] - 0s 489ms/step - ndcg_metric: 0.9223 - mrr_metric: 1.0000 - opa_metric: 0.8500 - loss: 6.1223 - regularization_loss: 0.0000e+00 - total_loss: 6.1223\n",
      "Val metrics: {'ndcg_metric': 0.9073922038078308, 'mrr_metric': 1.0, 'opa_metric': 0.7979999780654907, 'loss': 7.247115612030029, 'regularization_loss': 0, 'total_loss': 7.247115612030029}\n"
     ]
    }
   ],
   "source": [
    "val_metrics = model.evaluate(val_images, val_labels, return_dict=True)\n",
    "print(f\"Val metrics: {test_metrics}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 246,
   "id": "7a1b9b60-df14-499e-ad30-d3945e92c102",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(25088, 512), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1ca0567670>, 139760867740528), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(25088, 512), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1ca0567670>, 139760867740528), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(512,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1ded7bd1e0>, 139760045863632), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(512,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1ded7bd1e0>, 139760045863632), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(512, 256), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c6bc252a0>, 139760867738848), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(512, 256), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c6bc252a0>, 139760867738848), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(256,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c6bc27bb0>, 139760867746928), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(256,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c6bc27bb0>, 139760867746928), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(256, 128), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c6bc27880>, 139760388035872), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(256, 128), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c6bc27880>, 139760388035872), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(128,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1ca8e69ae0>, 139760388032432), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(128,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1ca8e69ae0>, 139760388032432), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(128, 64), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1ded99b550>, 139759742016192), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(128, 64), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1ded99b550>, 139759742016192), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(64,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1da61b6110>, 139759742024192), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(64,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1da61b6110>, 139759742024192), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(64, 1), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1cc82af2e0>, 139760867749408), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(64, 1), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1cc82af2e0>, 139760867749408), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(1,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1cdddab8b0>, 139760867743248), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(1,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1cdddab8b0>, 139760867743248), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(25088, 512), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1ca0567670>, 139760867740528), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(25088, 512), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1ca0567670>, 139760867740528), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(512,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1ded7bd1e0>, 139760045863632), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(512,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1ded7bd1e0>, 139760045863632), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(512, 256), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c6bc252a0>, 139760867738848), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(512, 256), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c6bc252a0>, 139760867738848), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(256,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c6bc27bb0>, 139760867746928), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(256,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c6bc27bb0>, 139760867746928), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(256, 128), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c6bc27880>, 139760388035872), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(256, 128), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c6bc27880>, 139760388035872), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(128,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1ca8e69ae0>, 139760388032432), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(128,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1ca8e69ae0>, 139760388032432), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(128, 64), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1ded99b550>, 139759742016192), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(128, 64), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1ded99b550>, 139759742016192), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(64,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1da61b6110>, 139759742024192), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(64,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1da61b6110>, 139759742024192), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(64, 1), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1cc82af2e0>, 139760867749408), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(64, 1), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1cc82af2e0>, 139760867749408), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(1,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1cdddab8b0>, 139760867743248), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(1,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1cdddab8b0>, 139760867743248), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Assets written to: tesla_saved_models/VGG16RankingModel_20240614_010659_unfreezed_0.0001/assets\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Assets written to: tesla_saved_models/VGG16RankingModel_20240614_010659_unfreezed_0.0001/assets\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model saved in tesla_saved_models/VGG16RankingModel_20240614_010659_unfreezed_0.0001 as VGG16RankingModel_20240614_010659_unfreezed_0.0001\n",
      "CPU times: user 2.37 s, sys: 361 ms, total: 2.73 s\n",
      "Wall time: 2.68 s\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "\n",
    "if SAVE_MODEL:\n",
    "    save_model_with_timestamp_and_lr(model, lr, freezed=\"freezed\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8cc4d7b7-6abb-42d5-b85b-735e6e1b2e64",
   "metadata": {},
   "source": [
    "### Evaluating"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "89f9004f-8fc0-45e5-b236-399766a13b73",
   "metadata": {},
   "outputs": [],
   "source": [
    "loaded_model = tf.keras.models.load_model('tesla_saved_models/VGG16RankingModel_20240614_054453_freezed_0.0001', \n",
    "                                          custom_objects={'ListMLELoss': tfr.keras.losses.ListMLELoss()})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a7b12f54-d462-4ae2-8dbe-ea2beb8d3973",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "%%time\n",
    "\n",
    "metrics = compute_dataset_metrics(loaded_model, train_dataset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "afd79225-cb32-4c89-947a-979d53a39f71",
   "metadata": {},
   "outputs": [],
   "source": [
    "%%time\n",
    "\n",
    "print(metrics)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "51b48f86-d5d7-42de-90ef-fe59b3a82ae0",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "%%time\n",
    "\n",
    "metrics = compute_dataset_metrics(loaded_model, val_dataset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "514cc75a-695e-473c-a692-08b192dd7784",
   "metadata": {},
   "outputs": [],
   "source": [
    "%%time\n",
    "\n",
    "print(metrics)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b841f871-dfde-44d5-ab32-017a54bc60bf",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "%%time\n",
    "\n",
    "metrics = compute_dataset_metrics(loaded_model, test_dataset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "19a886fd-43b4-4fc1-98b4-37877b3ab383",
   "metadata": {},
   "outputs": [],
   "source": [
    "%%time\n",
    "\n",
    "print(metrics)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b4d63e04-7655-47b0-93c0-57524d5e4504",
   "metadata": {},
   "outputs": [],
   "source": [
    "%%time\n",
    "\n",
    "show_on_test_batch = True"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6b3465f0-c6d3-416c-be29-78b92ba9d510",
   "metadata": {},
   "outputs": [],
   "source": [
    "%%time\n",
    "\n",
    "if show_on_test_batch:\n",
    "    batch = test_dataset.take(1)\n",
    "    print(batch)\n",
    "    for images, labels in batch:\n",
    "        print(model(images)[0], labels[0], predict_rank(model, images)[0])\n",
    "        print(model(images)[1], labels[1], predict_rank(model, images)[1])\n",
    "        print(model(images)[2], labels[2], predict_rank(model, images)[2])\n",
    "        break"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "84969093-5714-4beb-a811-917b6e1763ce",
   "metadata": {},
   "outputs": [],
   "source": [
    "%%time\n",
    "\n",
    "plot_metrics(metrics)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "174bd161-6847-46c3-a0eb-82802bd46a4c",
   "metadata": {},
   "source": [
    "## VGG19"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ee817ed7-b1b1-4a0d-b1f4-fefb6f5f9aeb",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true
   },
   "source": [
    "### Training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 247,
   "id": "2af1b2ca-14ee-44b0-a7d0-50155a58d0ef",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 334 ms, sys: 80.6 ms, total: 414 ms\n",
      "Wall time: 405 ms\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "\n",
    "loss = tfr.keras.losses.ListMLELoss()\n",
    "model = VGG19RankingModel(loss)\n",
    "lr = 1e-4\n",
    "optimizer = tf.keras.optimizers.Adam(learning_rate=lr, clipnorm=1.0)\n",
    "model.compile(optimizer=optimizer)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 248,
   "id": "8e6e255d-de07-488c-ac82-95a274b16b6e",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 3.17 ms, sys: 132 μs, total: 3.3 ms\n",
      "Wall time: 1.48 ms\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "\n",
    "batch = train_dataset.take(1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 249,
   "id": "769a060d-0d59-42a7-aef8-85218edd9c55",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/1 [==============================] - 7s 7s/step\n",
      "CPU times: user 7.15 s, sys: 399 ms, total: 7.55 s\n",
      "Wall time: 7.29 s\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2024-06-14 01:07:09.959466: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 14764271650021614035\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "array([[ 0.24851231,  0.353503  ,  0.28924406,  0.22375408,  0.13541214],\n",
       "       [ 0.31440052,  0.23557281,  0.18474147,  0.22031453,  0.2968847 ],\n",
       "       [ 0.20243444,  0.23200569,  0.25948408,  0.21912439,  0.15240438],\n",
       "       [ 0.33570775,  0.15599027,  0.16067326,  0.14401233,  0.14870936],\n",
       "       [ 0.23588142,  0.2846503 ,  0.20993447,  0.2699514 ,  0.455951  ],\n",
       "       [ 0.15908134,  0.17388204,  0.28902015,  0.19656968,  0.17617184],\n",
       "       [ 0.2672039 ,  0.1601595 ,  0.28014863,  0.32674482,  0.34671682],\n",
       "       [ 0.24911776,  0.32620266,  0.05495804,  0.00632505,  0.18030083],\n",
       "       [ 0.14215411,  0.18230128,  0.07621626,  0.20018819,  0.2544415 ],\n",
       "       [ 0.27248716,  0.1017893 ,  0.41483134,  0.30920961,  0.39430666],\n",
       "       [ 0.16098009,  0.20512353,  0.30328655,  0.10231443,  0.19844982],\n",
       "       [ 0.34305683,  0.16081622,  0.15379712,  0.31129992,  0.25466862],\n",
       "       [ 0.1925501 ,  0.23970169,  0.13796426,  0.25646576,  0.16706195],\n",
       "       [ 0.15582554,  0.18731698,  0.14971513,  0.10735915,  0.14434034],\n",
       "       [ 0.28461158,  0.26142383,  0.12642473,  0.15036878,  0.2068874 ],\n",
       "       [ 0.1291858 ,  0.09499401,  0.07976632,  0.3914611 ,  0.20463125],\n",
       "       [ 0.27607325,  0.27993906,  0.20050853,  0.25166652,  0.12480047],\n",
       "       [ 0.04869433,  0.16894802,  0.19080041,  0.17609543,  0.15893447],\n",
       "       [-0.00727175,  0.14082408,  0.39246088,  0.28866622,  0.18955664],\n",
       "       [ 0.14441782,  0.18955255,  0.2441347 ,  0.16568397,  0.15565935],\n",
       "       [ 0.04393268,  0.2698086 ,  0.27902162,  0.1998025 ,  0.06445878],\n",
       "       [ 0.20801088,  0.30805165,  0.19335562,  0.17959586,  0.23599008],\n",
       "       [ 0.13437301,  0.13182712,  0.37523267,  0.13980347,  0.29226923],\n",
       "       [ 0.09946127,  0.18742481,  0.09360301,  0.16984597,  0.1409869 ],\n",
       "       [ 0.16160685,  0.0707093 ,  0.07503207,  0.27276325,  0.18224494],\n",
       "       [ 0.04399318,  0.0973418 ,  0.15604551,  0.12890756,  0.16124237],\n",
       "       [ 0.03133535,  0.08822724,  0.21795887,  0.1834386 ,  0.21298674],\n",
       "       [ 0.14528672,  0.24296999,  0.16093922,  0.08956078,  0.22030684],\n",
       "       [ 0.1509003 ,  0.37600386,  0.4584034 ,  0.22183922,  0.35968548],\n",
       "       [ 0.18208738,  0.25961205,  0.40548363,  0.17767355,  0.26698136]],\n",
       "      dtype=float32)"
      ]
     },
     "execution_count": 249,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "%%time\n",
    "\n",
    "model.predict(batch)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 250,
   "id": "1ef9cc5f-cc8f-48da-b3e0-f9a6ffa0254d",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model: \"vgg19_ranking_model\"\n",
      "_________________________________________________________________\n",
      " Layer (type)                Output Shape              Param #   \n",
      "=================================================================\n",
      " vgg19 (Functional)          (None, 7, 7, 512)         20024384  \n",
      "                                                                 \n",
      " flatten_8 (Flatten)         multiple                  0         \n",
      "                                                                 \n",
      " sequential_16 (Sequential)  (None, 64)                13018048  \n",
      "                                                                 \n",
      " sequential_17 (Sequential)  (None, 1)                 65        \n",
      "                                                                 \n",
      " ranking_8 (Ranking)         multiple                  0 (unused)\n",
      "                                                                 \n",
      "=================================================================\n",
      "Total params: 33042497 (126.05 MB)\n",
      "Trainable params: 33042497 (126.05 MB)\n",
      "Non-trainable params: 0 (0.00 Byte)\n",
      "_________________________________________________________________\n",
      "CPU times: user 22 ms, sys: 3.76 ms, total: 25.8 ms\n",
      "Wall time: 23.2 ms\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "\n",
    "model.summary()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 251,
   "id": "a0ba913e-beb7-4d5b-ae66-5a863a68d0ba",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1/10\n",
      "125/125 [==============================] - 513s 4s/step - ndcg_metric: 0.8358 - mrr_metric: 0.9668 - opa_metric: 0.7179 - loss: 4.0895 - regularization_loss: 0.0000e+00 - total_loss: 4.0895\n",
      "Epoch 2/10\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2024-06-14 01:15:43.272732: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 1088254430448484430\n",
      "2024-06-14 01:15:43.272818: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 6106416586379984997\n",
      "2024-06-14 01:15:43.272844: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 14764271650021614035\n",
      "2024-06-14 01:15:43.272864: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 14785363395867255621\n",
      "2024-06-14 01:15:43.272881: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 4689267766182724855\n",
      "2024-06-14 01:15:43.272898: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 7877673172003818391\n",
      "2024-06-14 01:15:43.272914: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 676314631351954365\n",
      "2024-06-14 01:15:43.272931: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 4612173784223714079\n",
      "2024-06-14 01:15:43.272948: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 6802962521415675681\n",
      "2024-06-14 01:15:43.272964: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 188325259605078259\n",
      "2024-06-14 01:15:43.272980: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 443208944460132669\n",
      "2024-06-14 01:15:43.272996: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 4855144996701696305\n",
      "2024-06-14 01:15:43.273013: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 17758073706722759615\n",
      "2024-06-14 01:15:43.273029: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 2595200501857358243\n",
      "2024-06-14 01:15:43.273045: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 654587162800092763\n",
      "2024-06-14 01:15:43.273062: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 15806481289073784875\n",
      "2024-06-14 01:15:43.273078: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 1618082295076665497\n",
      "2024-06-14 01:15:43.273097: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 8943907292632349059\n",
      "2024-06-14 01:15:43.273114: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 15941820913073298467\n",
      "2024-06-14 01:15:43.273135: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 11537329754499644795\n",
      "2024-06-14 01:15:43.273160: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 15913402631510866199\n",
      "2024-06-14 01:15:43.273198: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 8921556161878475176\n",
      "2024-06-14 01:15:43.273222: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 8452474876992302900\n",
      "2024-06-14 01:15:43.273239: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 15184123812925850498\n",
      "2024-06-14 01:15:43.273255: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 15697216796008838270\n",
      "2024-06-14 01:15:43.273272: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 12684653093568628192\n",
      "2024-06-14 01:15:43.273301: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 481864624172480188\n",
      "2024-06-14 01:15:43.273331: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 10917030791734319532\n",
      "2024-06-14 01:15:43.273364: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 14874925800357431768\n",
      "2024-06-14 01:15:43.273395: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 1499985239587855712\n",
      "2024-06-14 01:15:43.273426: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 1657733684237282840\n",
      "2024-06-14 01:15:43.273540: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 17024944895194967702\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "125/125 [==============================] - 506s 4s/step - ndcg_metric: 0.8600 - mrr_metric: 0.9769 - opa_metric: 0.7503 - loss: 3.8229 - regularization_loss: 0.0000e+00 - total_loss: 3.8229\n",
      "Epoch 3/10\n",
      "125/125 [==============================] - 508s 4s/step - ndcg_metric: 0.8762 - mrr_metric: 0.9831 - opa_metric: 0.7749 - loss: 3.5918 - regularization_loss: 0.0000e+00 - total_loss: 3.5918\n",
      "Epoch 4/10\n",
      "125/125 [==============================] - 504s 4s/step - ndcg_metric: 0.7555 - mrr_metric: 0.9228 - opa_metric: 0.1752 - loss: nan - regularization_loss: 0.0000e+00 - total_loss: nan\n",
      "Epoch 5/10\n",
      "125/125 [==============================] - 503s 4s/step - ndcg_metric: 0.7207 - mrr_metric: 0.8960 - opa_metric: 0.0000e+00 - loss: nan - regularization_loss: 0.0000e+00 - total_loss: nan\n",
      "Epoch 6/10\n",
      "125/125 [==============================] - 504s 4s/step - ndcg_metric: 0.7176 - mrr_metric: 0.9032 - opa_metric: 0.0000e+00 - loss: nan - regularization_loss: 0.0000e+00 - total_loss: nan\n",
      "Epoch 7/10\n",
      "125/125 [==============================] - 503s 4s/step - ndcg_metric: 0.7188 - mrr_metric: 0.9061 - opa_metric: 0.0000e+00 - loss: nan - regularization_loss: 0.0000e+00 - total_loss: nan\n",
      "Epoch 8/10\n",
      "125/125 [==============================] - 504s 4s/step - ndcg_metric: 0.7212 - mrr_metric: 0.8952 - opa_metric: 0.0000e+00 - loss: nan - regularization_loss: 0.0000e+00 - total_loss: nan\n",
      "Epoch 9/10\n",
      "125/125 [==============================] - 505s 4s/step - ndcg_metric: 0.7223 - mrr_metric: 0.9016 - opa_metric: 0.0000e+00 - loss: nan - regularization_loss: 0.0000e+00 - total_loss: nan\n",
      "Epoch 10/10\n",
      "125/125 [==============================] - 506s 4s/step - ndcg_metric: 0.7147 - mrr_metric: 0.8980 - opa_metric: 0.0000e+00 - loss: nan - regularization_loss: 0.0000e+00 - total_loss: nan\n",
      "CPU times: user 1h 24min 25s, sys: 2min 42s, total: 1h 27min 7s\n",
      "Wall time: 1h 24min 18s\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "\n",
    "history = model.fit(train_dataset, epochs=EPOCHS)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 252,
   "id": "d30f6177-e67c-4477-af6d-b129232eae87",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 6 μs, sys: 0 ns, total: 6 μs\n",
      "Wall time: 12.9 μs\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "\n",
    "show_on_test_batch = True"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 253,
   "id": "a5de71ac-6f35-43c4-a472-03bdbf564498",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<_TakeDataset element_spec=(TensorSpec(shape=(None, 5, 224, 224, 3), dtype=tf.float32, name=None), TensorSpec(shape=(None, 5), dtype=tf.int32, name=None))>\n",
      "tf.Tensor([nan nan nan nan nan], shape=(5,), dtype=float32) tf.Tensor([4 3 2 0 1], shape=(5,), dtype=int32) tf.Tensor([0 1 2 3 4], shape=(5,), dtype=int32)\n",
      "tf.Tensor([nan nan nan nan nan], shape=(5,), dtype=float32) tf.Tensor([3 2 0 1 4], shape=(5,), dtype=int32) tf.Tensor([0 1 2 3 4], shape=(5,), dtype=int32)\n",
      "tf.Tensor([nan nan nan nan nan], shape=(5,), dtype=float32) tf.Tensor([4 3 0 2 1], shape=(5,), dtype=int32) tf.Tensor([0 1 2 3 4], shape=(5,), dtype=int32)\n",
      "CPU times: user 5.49 s, sys: 168 ms, total: 5.65 s\n",
      "Wall time: 4.94 s\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "\n",
    "if show_on_test_batch:\n",
    "    batch = train_dataset.take(1)\n",
    "    print(batch)\n",
    "    for images, labels in batch:\n",
    "        print(model(images)[0], labels[0], predict_rank(model, images)[0])\n",
    "        print(model(images)[1], labels[1], predict_rank(model, images)[1])\n",
    "        print(model(images)[2], labels[2], predict_rank(model, images)[2])\n",
    "        break"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 54,
   "id": "cd55f1d2-deaa-450e-ad29-0e6b43140492",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/1 [==============================] - 1s 964ms/step - ndcg_metric: 0.9098 - mrr_metric: 1.0000 - opa_metric: 0.8140 - loss: 7.4537 - regularization_loss: 0.0000e+00 - total_loss: 7.4537\n",
      "Test metrics: {'ndcg_metric': 0.9097903966903687, 'mrr_metric': 1.0, 'opa_metric': 0.8140000104904175, 'loss': 7.453701019287109, 'regularization_loss': 0, 'total_loss': 7.453701019287109}\n"
     ]
    }
   ],
   "source": [
    "test_metrics = model.evaluate(test_images, test_labels, return_dict=True)\n",
    "print(f\"Test metrics: {test_metrics}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 55,
   "id": "f2aa0967-e837-44bf-a1df-c2bf9c2338ad",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/1 [==============================] - 1s 556ms/step - ndcg_metric: 0.9357 - mrr_metric: 1.0000 - opa_metric: 0.8420 - loss: 7.0343 - regularization_loss: 0.0000e+00 - total_loss: 7.0343\n",
      "Val metrics: {'ndcg_metric': 0.9097903966903687, 'mrr_metric': 1.0, 'opa_metric': 0.8140000104904175, 'loss': 7.453701019287109, 'regularization_loss': 0, 'total_loss': 7.453701019287109}\n"
     ]
    }
   ],
   "source": [
    "val_metrics = model.evaluate(val_images, val_labels, return_dict=True)\n",
    "print(f\"Val metrics: {test_metrics}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 254,
   "id": "471673aa-a6d9-41ae-a08e-5f5cce1d70a9",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(25088, 512), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c9c17ad10>, 139760908139648), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(25088, 512), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c9c17ad10>, 139760908139648), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(512,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1ca050b400>, 139760916953440), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(512,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1ca050b400>, 139760916953440), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(512, 256), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1ca06e4280>, 139760916155248), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(512, 256), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1ca06e4280>, 139760916155248), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(256,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c8bcf5900>, 139760916153968), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(256,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c8bcf5900>, 139760916153968), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(256, 128), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c8bad1ba0>, 139760916161968), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(256, 128), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c8bad1ba0>, 139760916161968), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(128,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c8bad1c00>, 139760916108656), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(128,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c8bad1c00>, 139760916108656), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(128, 64), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c8bad3460>, 139760916579232), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(128, 64), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c8bad3460>, 139760916579232), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(64,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c8bad0640>, 139760916581312), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(64,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c8bad0640>, 139760916581312), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(64, 1), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c97431db0>, 139760916157888), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(64, 1), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c97431db0>, 139760916157888), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(1,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c945c66b0>, 139760916157248), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(1,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c945c66b0>, 139760916157248), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(25088, 512), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c9c17ad10>, 139760908139648), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(25088, 512), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c9c17ad10>, 139760908139648), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(512,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1ca050b400>, 139760916953440), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(512,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1ca050b400>, 139760916953440), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(512, 256), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1ca06e4280>, 139760916155248), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(512, 256), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1ca06e4280>, 139760916155248), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(256,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c8bcf5900>, 139760916153968), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(256,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c8bcf5900>, 139760916153968), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(256, 128), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c8bad1ba0>, 139760916161968), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(256, 128), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c8bad1ba0>, 139760916161968), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(128,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c8bad1c00>, 139760916108656), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(128,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c8bad1c00>, 139760916108656), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(128, 64), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c8bad3460>, 139760916579232), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(128, 64), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c8bad3460>, 139760916579232), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(64,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c8bad0640>, 139760916581312), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(64,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c8bad0640>, 139760916581312), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(64, 1), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c97431db0>, 139760916157888), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(64, 1), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c97431db0>, 139760916157888), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(1,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c945c66b0>, 139760916157248), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(1,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c945c66b0>, 139760916157248), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Assets written to: tesla_saved_models/VGG19RankingModel_20240614_023133_unfreezed_0.0001/assets\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Assets written to: tesla_saved_models/VGG19RankingModel_20240614_023133_unfreezed_0.0001/assets\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model saved in tesla_saved_models/VGG19RankingModel_20240614_023133_unfreezed_0.0001 as VGG19RankingModel_20240614_023133_unfreezed_0.0001\n",
      "CPU times: user 2.62 s, sys: 478 ms, total: 3.1 s\n",
      "Wall time: 3.05 s\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "\n",
    "if SAVE_MODEL:\n",
    "    save_model_with_timestamp_and_lr(model, lr, freezed=\"freezed\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "086c2116-c3ef-4548-9e57-c85fbfd6ae59",
   "metadata": {},
   "source": [
    "### Evaluating"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5e052cfb-ce50-4fa3-8ee2-9a05b7476859",
   "metadata": {},
   "outputs": [],
   "source": [
    "loaded_model = tf.keras.models.load_model('saved_models/VGG19RankingModel_20240612_111927', \n",
    "                                          custom_objects={'ListMLELoss': tfr.keras.losses.ListMLELoss()})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4dc52949-638c-4033-90d8-7d0881e245a3",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "%%time\n",
    "\n",
    "metrics = compute_dataset_metrics(loaded_model, train_dataset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f207e214-9fb2-46ea-9ef5-7c702b63f7ec",
   "metadata": {},
   "outputs": [],
   "source": [
    "%%time\n",
    "\n",
    "print(metrics)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "15234b12-dcc4-41fd-82fd-a2a6a3cb1f28",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "%%time\n",
    "\n",
    "metrics = compute_dataset_metrics(loaded_model, val_dataset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "34aa7512-dbe7-404e-8d92-e10ba92d9920",
   "metadata": {},
   "outputs": [],
   "source": [
    "%%time\n",
    "\n",
    "print(metrics)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5fedd65f-ddd3-4d65-a88d-db46b3e71e01",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "%%time\n",
    "\n",
    "metrics = compute_dataset_metrics(loaded_model, test_dataset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d29f034f-6e65-4770-aa5b-82a2d2c71478",
   "metadata": {},
   "outputs": [],
   "source": [
    "%%time\n",
    "\n",
    "print(metrics)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e13952d2-72f2-496d-8af3-f50a4a100d50",
   "metadata": {},
   "outputs": [],
   "source": [
    "%%time\n",
    "\n",
    "show_on_test_batch = True"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9e7662ba-4d64-48b9-8612-11d6aa4ec778",
   "metadata": {},
   "outputs": [],
   "source": [
    "%%time\n",
    "\n",
    "if show_on_test_batch:\n",
    "    batch = test_dataset.take(1)\n",
    "    print(batch)\n",
    "    for images, labels in batch:\n",
    "        print(model(images)[0], labels[0], predict_rank(model, images)[0])\n",
    "        print(model(images)[1], labels[1], predict_rank(model, images)[1])\n",
    "        print(model(images)[2], labels[2], predict_rank(model, images)[2])\n",
    "        break"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8478c6a1-6f2c-40a9-b674-0a1dd35d6689",
   "metadata": {},
   "outputs": [],
   "source": [
    "%%time\n",
    "\n",
    "plot_metrics(metrics)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "430ce73c-5823-4ccc-8c90-320b55f68f43",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true
   },
   "source": [
    "# Training models & evaluation (with unfreezing)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5b42d02f-12fe-49cf-9aa8-15b38fb57e62",
   "metadata": {},
   "source": [
    "## ResNet50"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "826bcfac-6144-4c04-aa34-b4d3e42e0e52",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true
   },
   "source": [
    "### Training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 271,
   "id": "162e09de-9008-4d49-9568-9a40e356912b",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 1.37 s, sys: 96 ms, total: 1.47 s\n",
      "Wall time: 1.42 s\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "\n",
    "loss = tfr.keras.losses.ListMLELoss()\n",
    "model = ResNet50RankingModel(loss, trainable=True)\n",
    "lr = 1e-3\n",
    "optimizer = tf.keras.optimizers.Adam(learning_rate=lr, clipnorm=1.0)\n",
    "model.compile(optimizer=optimizer)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 272,
   "id": "6724140b-207d-4856-a4d9-42762f4562b6",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 3.4 ms, sys: 0 ns, total: 3.4 ms\n",
      "Wall time: 1.51 ms\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "\n",
    "batch = train_dataset.take(1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 273,
   "id": "bc98913f-181c-4a2b-aa48-4835733c3536",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/1 [==============================] - 8s 8s/step\n",
      "CPU times: user 8.21 s, sys: 385 ms, total: 8.6 s\n",
      "Wall time: 8.35 s\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2024-06-14 03:54:54.072295: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 7321974310818004873\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "array([[0.78390723, 0.58904314, 0.71437407, 0.59793407, 0.70264053],\n",
       "       [0.5543509 , 0.5356322 , 0.6065063 , 0.621851  , 0.40078956],\n",
       "       [0.48376176, 0.5598928 , 0.6118214 , 0.5325117 , 0.6085828 ],\n",
       "       [0.5474928 , 0.55523324, 0.62929595, 0.5970243 , 0.5397865 ],\n",
       "       [0.29721522, 0.4027599 , 0.7163266 , 0.19117115, 0.6029313 ],\n",
       "       [0.53861356, 0.439802  , 0.6001715 , 0.582041  , 0.48338813],\n",
       "       [0.72766787, 0.66761875, 0.70588   , 0.66652876, 0.6578241 ],\n",
       "       [0.26078257, 0.56557363, 0.8297902 , 0.4939194 , 0.6432159 ],\n",
       "       [0.5536493 , 0.59414625, 0.43728393, 0.51706314, 0.4949068 ],\n",
       "       [0.22872746, 0.49352065, 0.44016343, 0.33669043, 0.5724662 ],\n",
       "       [0.5481889 , 0.6270533 , 0.4934302 , 0.4889312 , 0.66808677],\n",
       "       [0.5521113 , 0.65625924, 0.5275736 , 0.30452695, 0.70853585],\n",
       "       [0.531919  , 0.54379505, 0.6987594 , 0.63164943, 0.64329827],\n",
       "       [0.50611925, 0.6017725 , 0.5172237 , 0.45626092, 0.60369664],\n",
       "       [0.4852317 , 0.4418251 , 0.509687  , 0.59272945, 0.62954223],\n",
       "       [0.37056714, 0.46414277, 0.40876403, 0.447555  , 0.45105296],\n",
       "       [0.50202537, 0.63108176, 0.5701081 , 0.695008  , 0.6729012 ],\n",
       "       [0.4925757 , 0.46362025, 0.51929533, 0.5145162 , 0.51926404],\n",
       "       [0.46797508, 0.5485852 , 0.39012972, 0.517347  , 0.07987411],\n",
       "       [0.37868086, 0.40902966, 0.44787562, 0.35923135, 0.52667344],\n",
       "       [0.2303682 , 0.36640522, 0.6231297 , 0.4091492 , 0.5292743 ],\n",
       "       [0.6066557 , 0.5453904 , 0.5312056 , 0.6015274 , 0.6629761 ],\n",
       "       [0.43778297, 0.37253225, 0.63403535, 0.58462226, 0.51142985],\n",
       "       [0.31220192, 0.5197493 , 0.49290776, 0.6155454 , 0.5582367 ],\n",
       "       [0.50295573, 0.44277862, 0.4691074 , 0.55699146, 0.51910555],\n",
       "       [0.53185886, 0.49023297, 0.48679963, 0.5072357 , 0.5028553 ],\n",
       "       [0.46839243, 0.49569675, 0.5075837 , 0.53713614, 0.5320697 ],\n",
       "       [0.4882828 , 0.4545255 , 0.47865295, 0.5514097 , 0.3994506 ],\n",
       "       [0.27889168, 0.3586836 , 0.48244342, 0.53456354, 0.2753312 ],\n",
       "       [0.7217092 , 0.6344151 , 0.5983685 , 0.5394919 , 0.6883004 ]],\n",
       "      dtype=float32)"
      ]
     },
     "execution_count": 273,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "%%time\n",
    "\n",
    "model.predict(batch)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 274,
   "id": "11f8fbf7-79e5-4067-83d2-9e58675d0857",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model: \"res_net50_ranking_model_1\"\n",
      "_________________________________________________________________\n",
      " Layer (type)                Output Shape              Param #   \n",
      "=================================================================\n",
      " resnet50 (Functional)       (None, 7, 7, 2048)        23587712  \n",
      "                                                                 \n",
      " flatten_11 (Flatten)        multiple                  0         \n",
      "                                                                 \n",
      " sequential_22 (Sequential)  (None, 64)                51553216  \n",
      "                                                                 \n",
      " sequential_23 (Sequential)  (None, 1)                 65        \n",
      "                                                                 \n",
      " ranking_11 (Ranking)        multiple                  0 (unused)\n",
      "                                                                 \n",
      "=================================================================\n",
      "Total params: 75140993 (286.64 MB)\n",
      "Trainable params: 75087873 (286.44 MB)\n",
      "Non-trainable params: 53120 (207.50 KB)\n",
      "_________________________________________________________________\n",
      "CPU times: user 15.7 ms, sys: 11.6 ms, total: 27.3 ms\n",
      "Wall time: 25.1 ms\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "\n",
    "model.summary()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 275,
   "id": "7f89894a-2f50-40e1-a07d-e67b38e16228",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1/10\n",
      "    125/Unknown - 493s 4s/step - ndcg_metric: 0.7600 - mrr_metric: 0.9340 - opa_metric: 0.6130 - loss: 4.8724 - regularization_loss: 0.0000e+00 - total_loss: 4.8724"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2024-06-14 04:03:10.312182: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 12989881424137535623\n",
      "2024-06-14 04:03:10.312222: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 1030990000791773353\n",
      "2024-06-14 04:03:10.312233: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 4235485683151182505\n",
      "2024-06-14 04:03:10.312242: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 1665617716135232833\n",
      "2024-06-14 04:03:10.312250: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 15676609152484418829\n",
      "2024-06-14 04:03:10.312260: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 12305647672294230075\n",
      "2024-06-14 04:03:10.312271: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 11297363814829495897\n",
      "2024-06-14 04:03:10.312282: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 10312653086569853007\n",
      "2024-06-14 04:03:10.312292: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 10336776434770366799\n",
      "2024-06-14 04:03:10.312303: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 4086343452725154643\n",
      "2024-06-14 04:03:10.312315: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 15070458440930300833\n",
      "2024-06-14 04:03:10.312324: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 9909194245574087951\n",
      "2024-06-14 04:03:10.312332: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 18150139875790050841\n",
      "2024-06-14 04:03:10.312339: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 5477739728408999709\n",
      "2024-06-14 04:03:10.312347: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 3484373842234898885\n",
      "2024-06-14 04:03:10.312354: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 606410873458643023\n",
      "2024-06-14 04:03:10.312361: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 3864247054359403\n",
      "2024-06-14 04:03:10.312369: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 731991778821544941\n",
      "2024-06-14 04:03:10.312377: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 5551434146676464783\n",
      "2024-06-14 04:03:10.312384: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 7321974310818004873\n",
      "2024-06-14 04:03:10.312399: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 12523294065456382054\n",
      "2024-06-14 04:03:10.312410: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 5155015052912980536\n",
      "2024-06-14 04:03:10.312422: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 3196869903247622962\n",
      "2024-06-14 04:03:10.312433: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 16219902893010113946\n",
      "2024-06-14 04:03:10.312440: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 53464985064023646\n",
      "2024-06-14 04:03:10.312447: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 11763179780301744058\n",
      "2024-06-14 04:03:10.312455: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 8248675116470440638\n",
      "2024-06-14 04:03:10.312461: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 6849874708640729452\n",
      "2024-06-14 04:03:10.312469: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 841039418179250704\n",
      "2024-06-14 04:03:10.312478: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 15152595476291144700\n",
      "2024-06-14 04:03:10.312493: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 6509677380059245806\n",
      "2024-06-14 04:03:10.312633: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 3928258023429765196\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "125/125 [==============================] - 496s 4s/step - ndcg_metric: 0.7600 - mrr_metric: 0.9340 - opa_metric: 0.6130 - loss: 4.8705 - regularization_loss: 0.0000e+00 - total_loss: 4.8705\n",
      "Epoch 2/10\n",
      "125/125 [==============================] - 473s 4s/step - ndcg_metric: 0.7597 - mrr_metric: 0.9379 - opa_metric: 0.6156 - loss: 4.6101 - regularization_loss: 0.0000e+00 - total_loss: 4.6101\n",
      "Epoch 3/10\n",
      "125/125 [==============================] - 471s 4s/step - ndcg_metric: 0.7582 - mrr_metric: 0.9379 - opa_metric: 0.6149 - loss: 4.6075 - regularization_loss: 0.0000e+00 - total_loss: 4.6075\n",
      "Epoch 4/10\n",
      "125/125 [==============================] - 472s 4s/step - ndcg_metric: 0.7616 - mrr_metric: 0.9372 - opa_metric: 0.6198 - loss: 4.6757 - regularization_loss: 0.0000e+00 - total_loss: 4.6757\n",
      "Epoch 5/10\n",
      "125/125 [==============================] - 472s 4s/step - ndcg_metric: 0.7663 - mrr_metric: 0.9377 - opa_metric: 0.6202 - loss: 4.6188 - regularization_loss: 0.0000e+00 - total_loss: 4.6188\n",
      "Epoch 6/10\n",
      "125/125 [==============================] - 471s 4s/step - ndcg_metric: 0.7623 - mrr_metric: 0.9375 - opa_metric: 0.6164 - loss: 4.9531 - regularization_loss: 0.0000e+00 - total_loss: 4.9531\n",
      "Epoch 7/10\n",
      "125/125 [==============================] - 472s 4s/step - ndcg_metric: 0.7288 - mrr_metric: 0.9089 - opa_metric: 0.1203 - loss: nan - regularization_loss: 0.0000e+00 - total_loss: nan\n",
      "Epoch 8/10\n",
      "125/125 [==============================] - 472s 4s/step - ndcg_metric: 0.7189 - mrr_metric: 0.9040 - opa_metric: 0.0000e+00 - loss: nan - regularization_loss: 0.0000e+00 - total_loss: nan\n",
      "Epoch 9/10\n",
      "125/125 [==============================] - 472s 4s/step - ndcg_metric: 0.7165 - mrr_metric: 0.8969 - opa_metric: 0.0000e+00 - loss: nan - regularization_loss: 0.0000e+00 - total_loss: nan\n",
      "Epoch 10/10\n",
      "125/125 [==============================] - 472s 4s/step - ndcg_metric: 0.7153 - mrr_metric: 0.8972 - opa_metric: 0.0000e+00 - loss: nan - regularization_loss: 0.0000e+00 - total_loss: nan\n",
      "CPU times: user 1h 21min 10s, sys: 3min 15s, total: 1h 24min 25s\n",
      "Wall time: 1h 19min 4s\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "\n",
    "history = model.fit(train_dataset, epochs=EPOCHS)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 276,
   "id": "e3d38282-d14f-407f-af14-86b424eaac6a",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 6 μs, sys: 1 μs, total: 7 μs\n",
      "Wall time: 12.6 μs\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "\n",
    "show_on_test_batch = True"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 277,
   "id": "1e1a70fa-50fb-45f6-89ac-ddc4b06aa712",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<_TakeDataset element_spec=(TensorSpec(shape=(None, 5, 224, 224, 3), dtype=tf.float32, name=None), TensorSpec(shape=(None, 5), dtype=tf.int32, name=None))>\n",
      "tf.Tensor([nan nan nan nan nan], shape=(5,), dtype=float32) tf.Tensor([3 2 1 4 0], shape=(5,), dtype=int32) tf.Tensor([0 1 2 3 4], shape=(5,), dtype=int32)\n",
      "tf.Tensor([nan nan nan nan nan], shape=(5,), dtype=float32) tf.Tensor([4 1 3 2 0], shape=(5,), dtype=int32) tf.Tensor([0 1 2 3 4], shape=(5,), dtype=int32)\n",
      "tf.Tensor([nan nan nan nan nan], shape=(5,), dtype=float32) tf.Tensor([0 2 4 1 3], shape=(5,), dtype=int32) tf.Tensor([0 1 2 3 4], shape=(5,), dtype=int32)\n",
      "CPU times: user 4.73 s, sys: 145 ms, total: 4.88 s\n",
      "Wall time: 4.2 s\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "\n",
    "if show_on_test_batch:\n",
    "    batch = train_dataset.take(1)\n",
    "    print(batch)\n",
    "    for images, labels in batch:\n",
    "        print(model(images)[0], labels[0], predict_rank(model, images)[0])\n",
    "        print(model(images)[1], labels[1], predict_rank(model, images)[1])\n",
    "        print(model(images)[2], labels[2], predict_rank(model, images)[2])\n",
    "        break"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "a53da5de-bc08-48fe-8d04-20b4b6308d77",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/1 [==============================] - 8s 8s/step - ndcg_metric: 0.8861 - mrr_metric: 1.0000 - opa_metric: 0.7840 - loss: 10.1475 - regularization_loss: 0.0000e+00 - total_loss: 10.1475\n",
      "Test metrics: {'ndcg_metric': 0.8861386179924011, 'mrr_metric': 1.0, 'opa_metric': 0.7839999794960022, 'loss': 10.147453308105469, 'regularization_loss': 0, 'total_loss': 10.147453308105469}\n"
     ]
    }
   ],
   "source": [
    "test_metrics = model.evaluate(test_images, test_labels, return_dict=True)\n",
    "print(f\"Test metrics: {test_metrics}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "f5fdef1a-b48c-4a47-a24c-eb282d5753f8",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/1 [==============================] - 0s 408ms/step - ndcg_metric: 0.8906 - mrr_metric: 1.0000 - opa_metric: 0.8020 - loss: 10.3791 - regularization_loss: 0.0000e+00 - total_loss: 10.3791\n",
      "Val metrics: {'ndcg_metric': 0.8861386179924011, 'mrr_metric': 1.0, 'opa_metric': 0.7839999794960022, 'loss': 10.147453308105469, 'regularization_loss': 0, 'total_loss': 10.147453308105469}\n"
     ]
    }
   ],
   "source": [
    "val_metrics = model.evaluate(val_images, val_labels, return_dict=True)\n",
    "print(f\"Val metrics: {test_metrics}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 278,
   "id": "0e5e8f80-d738-4d0e-bcd8-b178572ff36e",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(100352, 512), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c97b2cc10>, 139760039361664), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(100352, 512), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c97b2cc10>, 139760039361664), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(512,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c9ddb5930>, 139760037918272), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(512,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c9ddb5930>, 139760037918272), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(512, 256), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1ca8111450>, 139760039354784), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(512, 256), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1ca8111450>, 139760039354784), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(256,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c9e644280>, 139760039355344), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(256,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c9e644280>, 139760039355344), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(256, 128), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c9e66fb50>, 139760039347264), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(256, 128), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c9e66fb50>, 139760039347264), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(128,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c9dce71f0>, 139760039355024), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(128,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c9dce71f0>, 139760039355024), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(128, 64), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c9f85eb00>, 139760038320992), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(128, 64), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c9f85eb00>, 139760038320992), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(64,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c9fdddc00>, 139760038325312), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(64,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c9fdddc00>, 139760038325312), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(64, 1), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1cf6a356f0>, 139760038327232), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(64, 1), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1cf6a356f0>, 139760038327232), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(1,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1cf6a36bf0>, 139760038321712), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(1,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1cf6a36bf0>, 139760038321712), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(100352, 512), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c97b2cc10>, 139760039361664), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(100352, 512), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c97b2cc10>, 139760039361664), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(512,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c9ddb5930>, 139760037918272), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(512,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c9ddb5930>, 139760037918272), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(512, 256), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1ca8111450>, 139760039354784), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(512, 256), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1ca8111450>, 139760039354784), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(256,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c9e644280>, 139760039355344), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(256,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c9e644280>, 139760039355344), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(256, 128), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c9e66fb50>, 139760039347264), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(256, 128), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c9e66fb50>, 139760039347264), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(128,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c9dce71f0>, 139760039355024), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(128,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c9dce71f0>, 139760039355024), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(128, 64), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c9f85eb00>, 139760038320992), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(128, 64), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c9f85eb00>, 139760038320992), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(64,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c9fdddc00>, 139760038325312), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(64,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c9fdddc00>, 139760038325312), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(64, 1), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1cf6a356f0>, 139760038327232), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(64, 1), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1cf6a356f0>, 139760038327232), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(1,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1cf6a36bf0>, 139760038321712), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(1,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1cf6a36bf0>, 139760038321712), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Assets written to: tesla_saved_models/ResNet50RankingModel_20240614_051402_unfreezed_0.001/assets\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Assets written to: tesla_saved_models/ResNet50RankingModel_20240614_051402_unfreezed_0.001/assets\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model saved in tesla_saved_models/ResNet50RankingModel_20240614_051402_unfreezed_0.001 as ResNet50RankingModel_20240614_051402_unfreezed_0.001\n",
      "CPU times: user 16.4 s, sys: 1.19 s, total: 17.6 s\n",
      "Wall time: 17.4 s\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "\n",
    "if SAVE_MODEL:\n",
    "    save_model_with_timestamp_and_lr(model, lr)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "747f99c9-f867-4cf9-a197-6a607576be02",
   "metadata": {},
   "source": [
    "### Evaluating"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4eeec5c5-4431-4c9f-826f-a76ff39000e2",
   "metadata": {},
   "outputs": [],
   "source": [
    "loaded_model = tf.keras.models.load_model('tesla_saved_models/ResNet50RankingModel_20240612_180557_unfreezed_1e-05', \n",
    "                                          custom_objects={'ListMLELoss': tfr.keras.losses.ListMLELoss()})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ea664609-73b2-4f2b-a636-151348b4a9cb",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "%%time\n",
    "\n",
    "metrics = compute_dataset_metrics(loaded_model, train_dataset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dcfef93a-23f6-4efc-a099-efb7f57442a0",
   "metadata": {},
   "outputs": [],
   "source": [
    "%%time\n",
    "\n",
    "print(metrics)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9921b9ef-b83a-40dc-be25-43218c7e6053",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "%%time\n",
    "\n",
    "metrics = compute_dataset_metrics(loaded_model, val_dataset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "df1b0f64-7b47-492c-8908-5cd78743369e",
   "metadata": {},
   "outputs": [],
   "source": [
    "%%time\n",
    "\n",
    "print(metrics)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4b80f1ae-8020-4455-a20a-290f683448a5",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "%%time\n",
    "\n",
    "metrics = compute_dataset_metrics(loaded_model, test_dataset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e36ed900-8a38-4c40-9de1-18161a4934bf",
   "metadata": {},
   "outputs": [],
   "source": [
    "%%time\n",
    "\n",
    "print(metrics)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f2e6611e-4533-409f-ae39-68356a291d92",
   "metadata": {},
   "outputs": [],
   "source": [
    "%%time\n",
    "\n",
    "show_on_test_batch = True"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "547d0b9e-4b05-4dff-99bf-9645d9ba048d",
   "metadata": {},
   "outputs": [],
   "source": [
    "%%time\n",
    "\n",
    "if show_on_test_batch:\n",
    "    batch = test_dataset.take(1)\n",
    "    print(batch)\n",
    "    for images, labels in batch:\n",
    "        print(model(images)[0], labels[0], predict_rank(model, images)[0])\n",
    "        print(model(images)[1], labels[1], predict_rank(model, images)[1])\n",
    "        print(model(images)[2], labels[2], predict_rank(model, images)[2])\n",
    "        break"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "59051188-f0cf-43e7-95de-cc801d31649e",
   "metadata": {},
   "outputs": [],
   "source": [
    "%%time\n",
    "\n",
    "plot_metrics(metrics)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6c0da65a-bd02-491e-87c4-54c0b05bde78",
   "metadata": {},
   "source": [
    "## EfficientNet-B0"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "45532014-44d8-48ce-b5d1-f3f6970453b3",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true
   },
   "source": [
    "### Training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "ba674eea-8eab-456a-ac68-85f40a962b07",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 1.47 s, sys: 44.9 ms, total: 1.51 s\n",
      "Wall time: 1.46 s\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "\n",
    "loss = tfr.keras.losses.ListMLELoss()\n",
    "lr = 1e-5\n",
    "model = EfficientNetB0RankingModel(loss, trainable=True)\n",
    "optimizer = tf.keras.optimizers.Adam(learning_rate=lr, clipnorm=1.0)\n",
    "model.compile(optimizer=optimizer)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "id": "12043454-d48b-43e2-8125-477ff4577caf",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 2.98 ms, sys: 128 μs, total: 3.1 ms\n",
      "Wall time: 1.35 ms\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "\n",
    "batch = train_dataset.take(1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "7fa73572-c388-4551-8461-bb17508019b9",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/1 [==============================] - 9s 9s/step\n",
      "CPU times: user 8.53 s, sys: 375 ms, total: 8.91 s\n",
      "Wall time: 8.61 s\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2024-06-12 16:45:36.407364: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 12509860583265284084\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "array([[-0.23089248, -0.23231903, -0.22943069, -0.22481644, -0.22886072],\n",
       "       [-0.22775808, -0.22844252, -0.23004653, -0.22690636, -0.22820956],\n",
       "       [-0.2251917 , -0.22465645, -0.2246133 , -0.22573799, -0.22781932],\n",
       "       [-0.22517058, -0.22531664, -0.22607577, -0.2256182 , -0.22651792],\n",
       "       [-0.23030083, -0.24055444, -0.24334359, -0.24203674, -0.24855101],\n",
       "       [-0.22649299, -0.23635752, -0.23390311, -0.22640505, -0.23611635],\n",
       "       [-0.22734389, -0.2318699 , -0.2296015 , -0.23056418, -0.22986478],\n",
       "       [-0.23418827, -0.23546946, -0.23470344, -0.23137027, -0.23955746],\n",
       "       [-0.22813323, -0.22759172, -0.22983454, -0.22563505, -0.22728524],\n",
       "       [-0.23316528, -0.23406503, -0.22786847, -0.23093192, -0.23254393],\n",
       "       [-0.23189795, -0.22653879, -0.2251928 , -0.22556815, -0.22879042],\n",
       "       [-0.24395415, -0.24576074, -0.22736509, -0.23541263, -0.23599729],\n",
       "       [-0.22811395, -0.22645502, -0.22597426, -0.23147124, -0.2327643 ],\n",
       "       [-0.22986782, -0.23059608, -0.22750461, -0.22960919, -0.22767895],\n",
       "       [-0.22609906, -0.23061682, -0.22912407, -0.23275669, -0.22584246],\n",
       "       [-0.22458234, -0.22531171, -0.22332224, -0.22495607, -0.22311938],\n",
       "       [-0.23014809, -0.22913882, -0.23425888, -0.23471828, -0.23021603],\n",
       "       [-0.2262434 , -0.22561674, -0.22562629, -0.22589405, -0.2251144 ],\n",
       "       [-0.23247527, -0.23590855, -0.2395062 , -0.2306858 , -0.23500793],\n",
       "       [-0.22629951, -0.22574387, -0.22601996, -0.22624283, -0.22565818],\n",
       "       [-0.24941725, -0.23688725, -0.24096395, -0.23018502, -0.23443103],\n",
       "       [-0.22744682, -0.23512012, -0.22739094, -0.22767477, -0.22581181],\n",
       "       [-0.2331517 , -0.2329645 , -0.23116218, -0.23426972, -0.23447022],\n",
       "       [-0.227407  , -0.22690427, -0.22603916, -0.22929326, -0.22465222],\n",
       "       [-0.22835588, -0.22473422, -0.22297223, -0.22213152, -0.22718213],\n",
       "       [-0.22614533, -0.22439148, -0.22405672, -0.22276044, -0.22292045],\n",
       "       [-0.22851528, -0.22402704, -0.22517082, -0.22571515, -0.2258856 ],\n",
       "       [-0.22325508, -0.2224171 , -0.22370137, -0.22332303, -0.22483152],\n",
       "       [-0.23319647, -0.23085847, -0.22889198, -0.23077944, -0.22663212],\n",
       "       [-0.23110051, -0.22953902, -0.22975631, -0.22953758, -0.22956161]],\n",
       "      dtype=float32)"
      ]
     },
     "execution_count": 24,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "%%time\n",
    "\n",
    "model.predict(batch)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "51971261-3d86-4f92-939a-f76db09a1e91",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model: \"efficient_net_b0_ranking_model_1\"\n",
      "_________________________________________________________________\n",
      " Layer (type)                Output Shape              Param #   \n",
      "=================================================================\n",
      " efficientnetb0 (Functional  (None, 7, 7, 1280)        4049571   \n",
      " )                                                               \n",
      "                                                                 \n",
      " flatten_2 (Flatten)         multiple                  0         \n",
      "                                                                 \n",
      " sequential_4 (Sequential)   (None, 64)                32285632  \n",
      "                                                                 \n",
      " sequential_5 (Sequential)   (None, 1)                 65        \n",
      "                                                                 \n",
      " ranking_2 (Ranking)         multiple                  0 (unused)\n",
      "                                                                 \n",
      "=================================================================\n",
      "Total params: 36335268 (138.61 MB)\n",
      "Trainable params: 36293245 (138.45 MB)\n",
      "Non-trainable params: 42023 (164.16 KB)\n",
      "_________________________________________________________________\n",
      "CPU times: user 26.8 ms, sys: 4.31 ms, total: 31.1 ms\n",
      "Wall time: 27 ms\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "\n",
    "model.summary()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "id": "a83d2f4e-96f7-4aaf-8e27-0f6139969d34",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1/10\n",
      "    125/Unknown - 481s 4s/step - ndcg_metric: 0.7479 - mrr_metric: 0.9260 - opa_metric: 0.5912 - loss: 4.5938 - regularization_loss: 0.0000e+00 - total_loss: 4.5938"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2024-06-12 16:53:40.614809: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 3705798032519791056\n",
      "2024-06-12 16:53:40.614860: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 18307516419292152721\n",
      "2024-06-12 16:53:40.614873: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 149096743755031685\n",
      "2024-06-12 16:53:40.614883: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 17470317784980629453\n",
      "2024-06-12 16:53:40.614892: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 6055884437706000921\n",
      "2024-06-12 16:53:40.614899: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 15868950829799459353\n",
      "2024-06-12 16:53:40.614907: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 15664293458172243181\n",
      "2024-06-12 16:53:40.614915: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 17207925778548190583\n",
      "2024-06-12 16:53:40.614922: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 10980825899782450557\n",
      "2024-06-12 16:53:40.614931: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 13374710643810832305\n",
      "2024-06-12 16:53:40.614939: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 15924912364333895491\n",
      "2024-06-12 16:53:40.614948: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 15278303340487414561\n",
      "2024-06-12 16:53:40.614955: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 1402196497191654623\n",
      "2024-06-12 16:53:40.614963: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 17514148450805929287\n",
      "2024-06-12 16:53:40.614971: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 16879761973467462637\n",
      "2024-06-12 16:53:40.614983: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 4768911724608011941\n",
      "2024-06-12 16:53:40.615005: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 14285493211701207758\n",
      "2024-06-12 16:53:40.615019: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 4162137857189929902\n",
      "2024-06-12 16:53:40.615031: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 17149226250694509754\n",
      "2024-06-12 16:53:40.615043: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 3370426413565437494\n",
      "2024-06-12 16:53:40.615055: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 1945129961133885304\n",
      "2024-06-12 16:53:40.615068: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 7896232833376905202\n",
      "2024-06-12 16:53:40.615081: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 1355071284010158486\n",
      "2024-06-12 16:53:40.615095: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 4850694189786800524\n",
      "2024-06-12 16:53:40.615110: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 4431268666220269334\n",
      "2024-06-12 16:53:40.615122: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 9594947553012920514\n",
      "2024-06-12 16:53:40.615135: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 13350381470575719760\n",
      "2024-06-12 16:53:40.615148: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 7587513674150434358\n",
      "2024-06-12 16:53:40.615160: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 8177686799467870780\n",
      "2024-06-12 16:53:40.615174: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 13137696976191134736\n",
      "2024-06-12 16:53:40.615432: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 10237853793303060016\n",
      "2024-06-12 16:53:40.615465: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 10550708237394845448\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "125/125 [==============================] - 484s 4s/step - ndcg_metric: 0.7479 - mrr_metric: 0.9260 - opa_metric: 0.5912 - loss: 4.5932 - regularization_loss: 0.0000e+00 - total_loss: 4.5932\n",
      "Epoch 2/10\n",
      "125/125 [==============================] - 458s 4s/step - ndcg_metric: 0.7745 - mrr_metric: 0.9424 - opa_metric: 0.6405 - loss: 4.4583 - regularization_loss: 0.0000e+00 - total_loss: 4.4583\n",
      "Epoch 3/10\n",
      "125/125 [==============================] - 459s 4s/step - ndcg_metric: 0.7933 - mrr_metric: 0.9511 - opa_metric: 0.6618 - loss: 4.3728 - regularization_loss: 0.0000e+00 - total_loss: 4.3728\n",
      "Epoch 4/10\n",
      "125/125 [==============================] - 460s 4s/step - ndcg_metric: 0.8046 - mrr_metric: 0.9553 - opa_metric: 0.6771 - loss: 4.3001 - regularization_loss: 0.0000e+00 - total_loss: 4.3001\n",
      "Epoch 5/10\n",
      "125/125 [==============================] - 461s 4s/step - ndcg_metric: 0.8123 - mrr_metric: 0.9580 - opa_metric: 0.6871 - loss: 4.2594 - regularization_loss: 0.0000e+00 - total_loss: 4.2594\n",
      "Epoch 6/10\n",
      "125/125 [==============================] - 459s 4s/step - ndcg_metric: 0.8168 - mrr_metric: 0.9599 - opa_metric: 0.6926 - loss: 4.2266 - regularization_loss: 0.0000e+00 - total_loss: 4.2266\n",
      "Epoch 7/10\n",
      "125/125 [==============================] - 461s 4s/step - ndcg_metric: 0.8201 - mrr_metric: 0.9612 - opa_metric: 0.6971 - loss: 4.1954 - regularization_loss: 0.0000e+00 - total_loss: 4.1954\n",
      "Epoch 8/10\n",
      "125/125 [==============================] - 460s 4s/step - ndcg_metric: 0.8218 - mrr_metric: 0.9633 - opa_metric: 0.7002 - loss: 4.1676 - regularization_loss: 0.0000e+00 - total_loss: 4.1676\n",
      "Epoch 9/10\n",
      "125/125 [==============================] - 461s 4s/step - ndcg_metric: 0.8233 - mrr_metric: 0.9647 - opa_metric: 0.7015 - loss: 4.1487 - regularization_loss: 0.0000e+00 - total_loss: 4.1487\n",
      "Epoch 10/10\n",
      "125/125 [==============================] - 460s 4s/step - ndcg_metric: 0.8270 - mrr_metric: 0.9667 - opa_metric: 0.7063 - loss: 4.1296 - regularization_loss: 0.0000e+00 - total_loss: 4.1296\n",
      "CPU times: user 1h 19min 35s, sys: 3min 2s, total: 1h 22min 38s\n",
      "Wall time: 1h 17min 5s\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "\n",
    "history = model.fit(train_dataset, epochs=EPOCHS)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "id": "afe5dd20-f797-4d6e-ab0d-33bb6545091c",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 6 μs, sys: 0 ns, total: 6 μs\n",
      "Wall time: 11.7 μs\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "\n",
    "show_on_test_batch = True"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "id": "7db80746-fc12-4ce5-b4e9-980d7de68356",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<_TakeDataset element_spec=(TensorSpec(shape=(None, 5, 224, 224, 3), dtype=tf.float32, name=None), TensorSpec(shape=(None, 5), dtype=tf.int32, name=None))>\n",
      "tf.Tensor([5.5617623 6.133828  5.143933  3.6822789 4.7525983], shape=(5,), dtype=float32) tf.Tensor([3 4 1 0 2], shape=(5,), dtype=int32) tf.Tensor([3 4 2 0 1], shape=(5,), dtype=int32)\n",
      "tf.Tensor([5.2262063 5.161729  5.182261  5.480153  4.0608344], shape=(5,), dtype=float32) tf.Tensor([4 1 2 3 0], shape=(5,), dtype=int32) tf.Tensor([3 1 2 4 0], shape=(5,), dtype=int32)\n",
      "tf.Tensor([4.8499074 3.1742752 5.113752  5.021641  4.6838384], shape=(5,), dtype=float32) tf.Tensor([0 1 4 3 2], shape=(5,), dtype=int32) tf.Tensor([2 0 4 3 1], shape=(5,), dtype=int32)\n",
      "CPU times: user 4.88 s, sys: 288 ms, total: 5.16 s\n",
      "Wall time: 4.35 s\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "\n",
    "if show_on_test_batch:\n",
    "    batch = train_dataset.take(1)\n",
    "    print(batch)\n",
    "    for images, labels in batch:\n",
    "        print(model(images)[0], labels[0], predict_rank(model, images)[0])\n",
    "        print(model(images)[1], labels[1], predict_rank(model, images)[1])\n",
    "        print(model(images)[2], labels[2], predict_rank(model, images)[2])\n",
    "        break"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "id": "97e2948e-f7fc-4b83-abbc-661697b0abd9",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/1 [==============================] - 4s 4s/step - ndcg_metric: 0.8564 - mrr_metric: 1.0000 - opa_metric: 0.7400 - loss: 3.9284 - regularization_loss: 0.0000e+00 - total_loss: 3.9284\n",
      "Test metrics: {'ndcg_metric': 0.8564373850822449, 'mrr_metric': 1.0, 'opa_metric': 0.7400000095367432, 'loss': 3.9283673763275146, 'regularization_loss': 0, 'total_loss': 3.9283673763275146}\n"
     ]
    }
   ],
   "source": [
    "test_metrics = model.evaluate(test_images, test_labels, return_dict=True)\n",
    "print(f\"Test metrics: {test_metrics}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "id": "f199a2e6-91c5-4fca-972c-956ecb1b67a8",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/1 [==============================] - 0s 314ms/step - ndcg_metric: 0.8898 - mrr_metric: 1.0000 - opa_metric: 0.8000 - loss: 3.7221 - regularization_loss: 0.0000e+00 - total_loss: 3.7221\n",
      "Val metrics: {'ndcg_metric': 0.8564373850822449, 'mrr_metric': 1.0, 'opa_metric': 0.7400000095367432, 'loss': 3.9283673763275146, 'regularization_loss': 0, 'total_loss': 3.9283673763275146}\n"
     ]
    }
   ],
   "source": [
    "val_metrics = model.evaluate(val_images, val_labels, return_dict=True)\n",
    "print(f\"Val metrics: {test_metrics}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "id": "ba23eb19-b9ff-45fc-b01a-c39f9f971b89",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(62720, 512), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1df8999750>, 139766571516752), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(62720, 512), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1df8999750>, 139766571516752), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(512,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1df899b430>, 139766571516272), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(512,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1df899b430>, 139766571516272), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(512, 256), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1df89cb970>, 139766509902688), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(512, 256), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1df89cb970>, 139766509902688), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(256,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1df89caa10>, 139766509900048), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(256,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1df89caa10>, 139766509900048), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(256, 128), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1df8aa07f0>, 139766509906368), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(256, 128), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1df8aa07f0>, 139766509906368), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(128,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1df8aa1b70>, 139766509896048), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(128,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1df8aa1b70>, 139766509896048), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(128, 64), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1df8af7820>, 139766535588000), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(128, 64), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1df8af7820>, 139766535588000), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(64,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1df8af5120>, 139766535589120), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(64,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1df8af5120>, 139766535589120), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(64, 1), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1df8af75e0>, 139766571511728), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(64, 1), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1df8af75e0>, 139766571511728), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(1,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1df8ac6650>, 139766571502768), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(1,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1df8ac6650>, 139766571502768), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(62720, 512), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1df8999750>, 139766571516752), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(62720, 512), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1df8999750>, 139766571516752), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(512,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1df899b430>, 139766571516272), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(512,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1df899b430>, 139766571516272), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(512, 256), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1df89cb970>, 139766509902688), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(512, 256), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1df89cb970>, 139766509902688), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(256,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1df89caa10>, 139766509900048), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(256,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1df89caa10>, 139766509900048), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(256, 128), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1df8aa07f0>, 139766509906368), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(256, 128), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1df8aa07f0>, 139766509906368), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(128,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1df8aa1b70>, 139766509896048), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(128,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1df8aa1b70>, 139766509896048), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(128, 64), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1df8af7820>, 139766535588000), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(128, 64), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1df8af7820>, 139766535588000), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(64,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1df8af5120>, 139766535589120), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(64,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1df8af5120>, 139766535589120), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(64, 1), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1df8af75e0>, 139766571511728), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(64, 1), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1df8af75e0>, 139766571511728), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(1,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1df8ac6650>, 139766571502768), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(1,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1df8ac6650>, 139766571502768), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Assets written to: tesla_saved_models/EfficientNetB0RankingModel_20240612_180246_unfreezed_1e-05/assets\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Assets written to: tesla_saved_models/EfficientNetB0RankingModel_20240612_180246_unfreezed_1e-05/assets\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model saved in tesla_saved_models/EfficientNetB0RankingModel_20240612_180246_unfreezed_1e-05 as EfficientNetB0RankingModel_20240612_180246_unfreezed_1e-05\n",
      "CPU times: user 23.9 s, sys: 764 ms, total: 24.7 s\n",
      "Wall time: 24.4 s\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "\n",
    "if SAVE_MODEL:\n",
    "    save_model_with_timestamp_and_lr(model, lr)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "54b9b8e0-ba8c-45a4-9370-115dc0bfcad7",
   "metadata": {},
   "source": [
    "### Evaluating"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0d0b98f7-30e4-4adf-befb-828fba43aed0",
   "metadata": {},
   "outputs": [],
   "source": [
    "loaded_model = tf.keras.models.load_model('tesla_saved_models/EfficientNetB0RankingModel_20240612_150051_unfreezed_0.0001', \n",
    "                                          custom_objects={'ListMLELoss': tfr.keras.losses.ListMLELoss()})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b46e5f55-a6a5-463b-ba18-dca57619f15d",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "%%time\n",
    "\n",
    "metrics = compute_dataset_metrics(loaded_model, train_dataset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7c1ce402-62cb-4f72-bc18-6fefddd377f3",
   "metadata": {},
   "outputs": [],
   "source": [
    "%%time\n",
    "\n",
    "print(metrics)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8cb52c32-6ca6-419e-9e8d-2185f894a54e",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "%%time\n",
    "\n",
    "metrics = compute_dataset_metrics(loaded_model, val_dataset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bec21fd2-5fc8-466f-bfba-50295626ca40",
   "metadata": {},
   "outputs": [],
   "source": [
    "%%time\n",
    "\n",
    "print(metrics)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d6b63713-dc65-4547-8070-fa37047b8a1e",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "%%time\n",
    "\n",
    "metrics = compute_dataset_metrics(loaded_model, test_dataset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c1c2e170-eb50-4f5a-a469-cd7a9aa5a7d9",
   "metadata": {},
   "outputs": [],
   "source": [
    "%%time\n",
    "\n",
    "print(metrics)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7e83ee8b-6e68-4db8-bf5f-bc3941cd9f61",
   "metadata": {},
   "outputs": [],
   "source": [
    "%%time\n",
    "\n",
    "show_on_test_batch = True"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d770f22d-2f45-4c76-8de5-93872f42d97c",
   "metadata": {},
   "outputs": [],
   "source": [
    "%%time\n",
    "\n",
    "if show_on_test_batch:\n",
    "    batch = test_dataset.take(1)\n",
    "    print(batch)\n",
    "    for images, labels in batch:\n",
    "        print(model(images)[0], labels[0], predict_rank(model, images)[0])\n",
    "        print(model(images)[1], labels[1], predict_rank(model, images)[1])\n",
    "        print(model(images)[2], labels[2], predict_rank(model, images)[2])\n",
    "        break"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "10ed918a-88bc-4864-8be8-becf94525678",
   "metadata": {},
   "outputs": [],
   "source": [
    "%%time\n",
    "\n",
    "plot_metrics(metrics)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1a9ea75b-c8aa-4b1b-b8e2-25bddde5e1b0",
   "metadata": {},
   "source": [
    "## EfficientNet-B1"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "65b13a09-c869-4261-b2d9-093d23d6fb45",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true
   },
   "source": [
    "### Training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "f1302208-7b36-483e-a2e0-4bfbec67405c",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 2.69 s, sys: 123 ms, total: 2.81 s\n",
      "Wall time: 2.8 s\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "\n",
    "loss = tfr.keras.losses.ListMLELoss()\n",
    "model = EfficientNetB1RankingModel(loss, trainable=True)\n",
    "lr = 1e-5\n",
    "optimizer = tf.keras.optimizers.Adam(learning_rate=lr, clipnorm=1.0)\n",
    "model.compile(optimizer=optimizer)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "c05b4e32-d45b-4ae1-889f-068af6bdfeb5",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 688 µs, sys: 3.67 ms, total: 4.36 ms\n",
      "Wall time: 3.02 ms\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "\n",
    "batch = train_dataset.take(1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "id": "c03b1843-8591-470d-becb-45b5b7166c6b",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/1 [==============================] - 14s 14s/step\n",
      "CPU times: user 13 s, sys: 1.26 s, total: 14.3 s\n",
      "Wall time: 13.6 s\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2024-06-11 07:51:42.533564: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 14066508150744407545\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "array([[0.48863536, 0.48524892, 0.49265447, 0.49299362, 0.4921349 ],\n",
       "       [0.47797713, 0.4970193 , 0.49148086, 0.5061696 , 0.49895638],\n",
       "       [0.48581207, 0.4879001 , 0.490749  , 0.4828997 , 0.48185697],\n",
       "       [0.47645313, 0.47801498, 0.4776703 , 0.47648677, 0.47703153],\n",
       "       [0.49673116, 0.48790556, 0.48174855, 0.49285626, 0.5070371 ],\n",
       "       [0.48587272, 0.48089215, 0.48402423, 0.48084524, 0.4856478 ],\n",
       "       [0.48972026, 0.49060506, 0.47771817, 0.48629376, 0.4846974 ],\n",
       "       [0.48917884, 0.489912  , 0.49372655, 0.49561727, 0.48928124],\n",
       "       [0.48328078, 0.49528092, 0.5027909 , 0.51053023, 0.4935999 ],\n",
       "       [0.48120716, 0.49607974, 0.48086548, 0.48827684, 0.4822861 ],\n",
       "       [0.48683116, 0.48941952, 0.48581928, 0.48091203, 0.4841023 ],\n",
       "       [0.5113285 , 0.5642727 , 0.5201831 , 0.5152973 , 0.5365057 ],\n",
       "       [0.48198217, 0.4807199 , 0.48309278, 0.47916594, 0.4767786 ],\n",
       "       [0.4755433 , 0.48830062, 0.48498452, 0.48696837, 0.4889611 ],\n",
       "       [0.4919056 , 0.48693937, 0.49178007, 0.47887284, 0.48826185],\n",
       "       [0.4795803 , 0.4792356 , 0.4768992 , 0.47912773, 0.4795985 ],\n",
       "       [0.48164693, 0.48088616, 0.47922832, 0.48134136, 0.4805187 ],\n",
       "       [0.493986  , 0.48817217, 0.48775154, 0.47896254, 0.4913339 ],\n",
       "       [0.5434466 , 0.50825346, 0.5753126 , 0.51261646, 0.53679985],\n",
       "       [0.4846582 , 0.48411736, 0.4999457 , 0.48301068, 0.48443758],\n",
       "       [0.49298936, 0.48188734, 0.50830317, 0.49732402, 0.48529002],\n",
       "       [0.47742394, 0.4779489 , 0.4774568 , 0.47755787, 0.4774792 ],\n",
       "       [0.48558083, 0.48643205, 0.49182245, 0.49179402, 0.4887301 ],\n",
       "       [0.48040012, 0.47967023, 0.47591105, 0.47924465, 0.48022184],\n",
       "       [0.4818664 , 0.48043463, 0.47888076, 0.47930527, 0.4824951 ],\n",
       "       [0.48562503, 0.49102008, 0.4783647 , 0.48787385, 0.49081624],\n",
       "       [0.53699964, 0.51946616, 0.5014612 , 0.5576102 , 0.5151754 ],\n",
       "       [0.491633  , 0.492709  , 0.4824439 , 0.48242489, 0.48625994],\n",
       "       [0.5130697 , 0.49914518, 0.4937545 , 0.500587  , 0.50450873],\n",
       "       [0.49137956, 0.49104673, 0.48506713, 0.48598057, 0.47898883]],\n",
       "      dtype=float32)"
      ]
     },
     "execution_count": 23,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "%%time\n",
    "\n",
    "model.predict(batch)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "8954317d-a9b3-4784-91c6-0cfd126e6fff",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model: \"efficient_net_b1_ranking_model\"\n",
      "_________________________________________________________________\n",
      " Layer (type)                Output Shape              Param #   \n",
      "=================================================================\n",
      " efficientnetb1 (Functional  (None, 7, 7, 1280)        6575239   \n",
      " )                                                               \n",
      "                                                                 \n",
      " flatten_2 (Flatten)         multiple                  0         \n",
      "                                                                 \n",
      " sequential_4 (Sequential)   (None, 64)                32285632  \n",
      "                                                                 \n",
      " sequential_5 (Sequential)   (None, 1)                 65        \n",
      "                                                                 \n",
      " ranking_2 (Ranking)         multiple                  0 (unused)\n",
      "                                                                 \n",
      "=================================================================\n",
      "Total params: 38860936 (148.24 MB)\n",
      "Trainable params: 38798881 (148.01 MB)\n",
      "Non-trainable params: 62055 (242.41 KB)\n",
      "_________________________________________________________________\n",
      "CPU times: user 29.1 ms, sys: 7.72 ms, total: 36.8 ms\n",
      "Wall time: 34.5 ms\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "\n",
    "model.summary()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "a20c4266-7bc1-4821-9769-357027a5572e",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1/10\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2024-06-11 07:52:21.854771: E external/local_xla/xla/stream_executor/gpu/gpu_cudamallocasync_allocator.cc:306] gpu_async_0 cuMemAllocAsync failed to allocate 270950400 bytes: CUDA error: out of memory (CUDA_ERROR_OUT_OF_MEMORY)\n",
      " Reported by CUDA: Free memory/Total memory: 182059008/11546394624\n",
      "2024-06-11 07:52:21.854824: E external/local_xla/xla/stream_executor/gpu/gpu_cudamallocasync_allocator.cc:311] Stats: Limit:                     10279845888\n",
      "InUse:                     10975064201\n",
      "MaxInUse:                  11181894625\n",
      "NumAllocs:                      114408\n",
      "MaxAllocSize:               1933711360\n",
      "Reserved:                            0\n",
      "PeakReserved:                        0\n",
      "LargestFreeBlock:                    0\n",
      "\n",
      "2024-06-11 07:52:21.855047: E external/local_xla/xla/stream_executor/gpu/gpu_cudamallocasync_allocator.cc:63] Histogram of current allocation: (allocation_size_in_bytes, nb_allocation_of_that_sizes), ...;\n",
      "2024-06-11 07:52:21.855058: E external/local_xla/xla/stream_executor/gpu/gpu_cudamallocasync_allocator.cc:66] 1, 33\n",
      "2024-06-11 07:52:21.855065: E external/local_xla/xla/stream_executor/gpu/gpu_cudamallocasync_allocator.cc:66] 4, 148\n",
      "2024-06-11 07:52:21.855071: E external/local_xla/xla/stream_executor/gpu/gpu_cudamallocasync_allocator.cc:66] 8, 40\n",
      "2024-06-11 07:52:21.855077: E external/local_xla/xla/stream_executor/gpu/gpu_cudamallocasync_allocator.cc:66] 12, 8\n",
      "2024-06-11 07:52:21.855083: E external/local_xla/xla/stream_executor/gpu/gpu_cudamallocasync_allocator.cc:66] 16, 13\n",
      "2024-06-11 07:52:21.855089: E external/local_xla/xla/stream_executor/gpu/gpu_cudamallocasync_allocator.cc:66] 24, 15\n",
      "2024-06-11 07:52:21.855095: E external/local_xla/xla/stream_executor/gpu/gpu_cudamallocasync_allocator.cc:66] 32, 6\n",
      "2024-06-11 07:52:21.855101: E external/local_xla/xla/stream_executor/gpu/gpu_cudamallocasync_allocator.cc:66] 40, 15\n",
      "2024-06-11 07:52:21.855107: E external/local_xla/xla/stream_executor/gpu/gpu_cudamallocasync_allocator.cc:66] 64, 35\n",
      "2024-06-11 07:52:21.855113: E external/local_xla/xla/stream_executor/gpu/gpu_cudamallocasync_allocator.cc:66] 80, 21\n",
      "2024-06-11 07:52:21.855118: E external/local_xla/xla/stream_executor/gpu/gpu_cudamallocasync_allocator.cc:66] 96, 40\n",
      "2024-06-11 07:52:21.855124: E external/local_xla/xla/stream_executor/gpu/gpu_cudamallocasync_allocator.cc:66] 112, 21\n",
      "2024-06-11 07:52:21.855130: E external/local_xla/xla/stream_executor/gpu/gpu_cudamallocasync_allocator.cc:66] 120, 3\n",
      "2024-06-11 07:52:21.855136: E external/local_xla/xla/stream_executor/gpu/gpu_cudamallocasync_allocator.cc:66] 128, 38\n",
      "2024-06-11 07:52:21.855142: E external/local_xla/xla/stream_executor/gpu/gpu_cudamallocasync_allocator.cc:66] 150, 2\n",
      "2024-06-11 07:52:21.855148: E external/local_xla/xla/stream_executor/gpu/gpu_cudamallocasync_allocator.cc:66] 160, 40\n",
      "2024-06-11 07:52:21.855154: E external/local_xla/xla/stream_executor/gpu/gpu_cudamallocasync_allocator.cc:66] 192, 27\n",
      "2024-06-11 07:52:21.855160: E external/local_xla/xla/stream_executor/gpu/gpu_cudamallocasync_allocator.cc:66] 256, 101\n",
      "2024-06-11 07:52:21.855166: E external/local_xla/xla/stream_executor/gpu/gpu_cudamallocasync_allocator.cc:66] 320, 59\n",
      "2024-06-11 07:52:21.855172: E external/local_xla/xla/stream_executor/gpu/gpu_cudamallocasync_allocator.cc:66] 384, 38\n",
      "2024-06-11 07:52:21.855177: E external/local_xla/xla/stream_executor/gpu/gpu_cudamallocasync_allocator.cc:66] 448, 56\n",
      "2024-06-11 07:52:21.855183: E external/local_xla/xla/stream_executor/gpu/gpu_cudamallocasync_allocator.cc:66] 512, 97\n",
      "2024-06-11 07:52:21.855189: E external/local_xla/xla/stream_executor/gpu/gpu_cudamallocasync_allocator.cc:66] 576, 98\n",
      "2024-06-11 07:52:21.855195: E external/local_xla/xla/stream_executor/gpu/gpu_cudamallocasync_allocator.cc:66] 600, 5\n",
      "2024-06-11 07:52:21.855201: E external/local_xla/xla/stream_executor/gpu/gpu_cudamallocasync_allocator.cc:66] 768, 72\n",
      "2024-06-11 07:52:21.855218: E external/local_xla/xla/stream_executor/gpu/gpu_cudamallocasync_allocator.cc:66] 960, 95\n",
      "2024-06-11 07:52:21.855224: E external/local_xla/xla/stream_executor/gpu/gpu_cudamallocasync_allocator.cc:66] 1024, 200\n",
      "2024-06-11 07:52:21.855231: E external/local_xla/xla/stream_executor/gpu/gpu_cudamallocasync_allocator.cc:66] 1028, 1\n",
      "2024-06-11 07:52:21.855237: E external/local_xla/xla/stream_executor/gpu/gpu_cudamallocasync_allocator.cc:66] 1152, 6\n",
      "2024-06-11 07:52:21.855244: E external/local_xla/xla/stream_executor/gpu/gpu_cudamallocasync_allocator.cc:66] 1280, 24\n",
      "2024-06-11 07:52:21.855252: E external/local_xla/xla/stream_executor/gpu/gpu_cudamallocasync_allocator.cc:66] 1536, 12\n",
      "2024-06-11 07:52:21.855258: E external/local_xla/xla/stream_executor/gpu/gpu_cudamallocasync_allocator.cc:66] 1920, 133\n",
      "2024-06-11 07:52:21.855265: E external/local_xla/xla/stream_executor/gpu/gpu_cudamallocasync_allocator.cc:66] 2048, 136\n",
      "2024-06-11 07:52:21.855271: E external/local_xla/xla/stream_executor/gpu/gpu_cudamallocasync_allocator.cc:66] 2400, 4\n",
      "2024-06-11 07:52:21.855278: E external/local_xla/xla/stream_executor/gpu/gpu_cudamallocasync_allocator.cc:66] 2688, 133\n",
      "2024-06-11 07:52:21.855284: E external/local_xla/xla/stream_executor/gpu/gpu_cudamallocasync_allocator.cc:66] 3000, 2\n",
      "2024-06-11 07:52:21.855290: E external/local_xla/xla/stream_executor/gpu/gpu_cudamallocasync_allocator.cc:66] 3456, 42\n",
      "2024-06-11 07:52:21.855297: E external/local_xla/xla/stream_executor/gpu/gpu_cudamallocasync_allocator.cc:66] 3600, 2\n",
      "2024-06-11 07:52:21.855303: E external/local_xla/xla/stream_executor/gpu/gpu_cudamallocasync_allocator.cc:66] 4096, 77\n",
      "2024-06-11 07:52:21.855310: E external/local_xla/xla/stream_executor/gpu/gpu_cudamallocasync_allocator.cc:66] 4608, 171\n",
      "2024-06-11 07:52:21.855316: E external/local_xla/xla/stream_executor/gpu/gpu_cudamallocasync_allocator.cc:66] 4800, 2\n",
      "2024-06-11 07:52:21.855323: E external/local_xla/xla/stream_executor/gpu/gpu_cudamallocasync_allocator.cc:66] 5120, 16\n",
      "2024-06-11 07:52:21.855329: E external/local_xla/xla/stream_executor/gpu/gpu_cudamallocasync_allocator.cc:66] 5184, 9\n",
      "2024-06-11 07:52:21.855335: E external/local_xla/xla/stream_executor/gpu/gpu_cudamallocasync_allocator.cc:66] 6144, 6\n",
      "2024-06-11 07:52:21.855342: E external/local_xla/xla/stream_executor/gpu/gpu_cudamallocasync_allocator.cc:66] 7680, 19\n",
      "2024-06-11 07:52:21.855348: E external/local_xla/xla/stream_executor/gpu/gpu_cudamallocasync_allocator.cc:66] 8192, 44\n",
      "2024-06-11 07:52:21.855355: E external/local_xla/xla/stream_executor/gpu/gpu_cudamallocasync_allocator.cc:66] 8640, 6\n",
      "2024-06-11 07:52:21.855361: E external/local_xla/xla/stream_executor/gpu/gpu_cudamallocasync_allocator.cc:66] 9216, 6\n",
      "2024-06-11 07:52:21.855367: E external/local_xla/xla/stream_executor/gpu/gpu_cudamallocasync_allocator.cc:66] 9600, 32\n",
      "2024-06-11 07:52:21.855374: E external/local_xla/xla/stream_executor/gpu/gpu_cudamallocasync_allocator.cc:66] 13824, 24\n",
      "2024-06-11 07:52:21.855380: E external/local_xla/xla/stream_executor/gpu/gpu_cudamallocasync_allocator.cc:66] 14400, 6\n",
      "2024-06-11 07:52:21.855387: E external/local_xla/xla/stream_executor/gpu/gpu_cudamallocasync_allocator.cc:66] 16384, 3\n",
      "2024-06-11 07:52:21.855393: E external/local_xla/xla/stream_executor/gpu/gpu_cudamallocasync_allocator.cc:66] 17280, 15\n",
      "2024-06-11 07:52:21.855400: E external/local_xla/xla/stream_executor/gpu/gpu_cudamallocasync_allocator.cc:66] 19200, 2\n",
      "2024-06-11 07:52:21.855562: E external/local_xla/xla/stream_executor/gpu/gpu_cudamallocasync_allocator.cc:66] 23040, 6\n",
      "2024-06-11 07:52:21.855569: E external/local_xla/xla/stream_executor/gpu/gpu_cudamallocasync_allocator.cc:66] 24000, 9\n",
      "2024-06-11 07:52:21.855575: E external/local_xla/xla/stream_executor/gpu/gpu_cudamallocasync_allocator.cc:66] 32768, 9\n",
      "2024-06-11 07:52:21.855582: E external/local_xla/xla/stream_executor/gpu/gpu_cudamallocasync_allocator.cc:66] 37632, 3\n",
      "2024-06-11 07:52:21.855588: E external/local_xla/xla/stream_executor/gpu/gpu_cudamallocasync_allocator.cc:66] 38400, 66\n",
      "2024-06-11 07:52:21.855595: E external/local_xla/xla/stream_executor/gpu/gpu_cudamallocasync_allocator.cc:66] 41472, 6\n",
      "2024-06-11 07:52:21.855601: E external/local_xla/xla/stream_executor/gpu/gpu_cudamallocasync_allocator.cc:66] 48000, 6\n",
      "2024-06-11 07:52:21.855608: E external/local_xla/xla/stream_executor/gpu/gpu_cudamallocasync_allocator.cc:66] 57600, 2\n",
      "2024-06-11 07:52:21.855614: E external/local_xla/xla/stream_executor/gpu/gpu_cudamallocasync_allocator.cc:66] 65536, 18\n",
      "2024-06-11 07:52:21.855621: E external/local_xla/xla/stream_executor/gpu/gpu_cudamallocasync_allocator.cc:66] 67200, 21\n",
      "2024-06-11 07:52:21.855628: E external/local_xla/xla/stream_executor/gpu/gpu_cudamallocasync_allocator.cc:66] 69120, 3\n",
      "2024-06-11 07:52:21.855635: E external/local_xla/xla/stream_executor/gpu/gpu_cudamallocasync_allocator.cc:66] 75264, 42\n",
      "2024-06-11 07:52:21.855642: E external/local_xla/xla/stream_executor/gpu/gpu_cudamallocasync_allocator.cc:66] 76800, 6\n",
      "2024-06-11 07:52:21.855648: E external/local_xla/xla/stream_executor/gpu/gpu_cudamallocasync_allocator.cc:66] 86400, 2\n",
      "2024-06-11 07:52:21.855655: E external/local_xla/xla/stream_executor/gpu/gpu_cudamallocasync_allocator.cc:66] 115200, 21\n",
      "2024-06-11 07:52:21.855661: E external/local_xla/xla/stream_executor/gpu/gpu_cudamallocasync_allocator.cc:66] 131072, 12\n",
      "2024-06-11 07:52:21.855668: E external/local_xla/xla/stream_executor/gpu/gpu_cudamallocasync_allocator.cc:66] 147456, 9\n",
      "2024-06-11 07:52:21.855674: E external/local_xla/xla/stream_executor/gpu/gpu_cudamallocasync_allocator.cc:66] 153600, 36\n",
      "2024-06-11 07:52:21.855681: E external/local_xla/xla/stream_executor/gpu/gpu_cudamallocasync_allocator.cc:66] 215040, 6\n",
      "2024-06-11 07:52:21.855687: E external/local_xla/xla/stream_executor/gpu/gpu_cudamallocasync_allocator.cc:66] 221184, 54\n",
      "2024-06-11 07:52:21.855693: E external/local_xla/xla/stream_executor/gpu/gpu_cudamallocasync_allocator.cc:66] 262144, 21\n",
      "2024-06-11 07:52:21.855700: E external/local_xla/xla/stream_executor/gpu/gpu_cudamallocasync_allocator.cc:66] 301056, 36\n",
      "2024-06-11 07:52:21.855706: E external/local_xla/xla/stream_executor/gpu/gpu_cudamallocasync_allocator.cc:66] 516096, 6\n",
      "2024-06-11 07:52:21.855713: E external/local_xla/xla/stream_executor/gpu/gpu_cudamallocasync_allocator.cc:66] 524288, 15\n",
      "2024-06-11 07:52:21.855719: E external/local_xla/xla/stream_executor/gpu/gpu_cudamallocasync_allocator.cc:66] 589824, 12\n",
      "2024-06-11 07:52:21.855726: E external/local_xla/xla/stream_executor/gpu/gpu_cudamallocasync_allocator.cc:66] 602112, 10\n",
      "2024-06-11 07:52:21.855733: E external/local_xla/xla/stream_executor/gpu/gpu_cudamallocasync_allocator.cc:66] 614400, 6\n",
      "2024-06-11 07:52:21.855739: E external/local_xla/xla/stream_executor/gpu/gpu_cudamallocasync_allocator.cc:66] 884736, 48\n",
      "2024-06-11 07:52:21.855745: E external/local_xla/xla/stream_executor/gpu/gpu_cudamallocasync_allocator.cc:66] 1048576, 33\n",
      "2024-06-11 07:52:21.855752: E external/local_xla/xla/stream_executor/gpu/gpu_cudamallocasync_allocator.cc:66] 1474560, 6\n",
      "2024-06-11 07:52:21.855758: E external/local_xla/xla/stream_executor/gpu/gpu_cudamallocasync_allocator.cc:66] 1638400, 6\n",
      "2024-06-11 07:52:21.855764: E external/local_xla/xla/stream_executor/gpu/gpu_cudamallocasync_allocator.cc:66] 2097152, 6\n",
      "2024-06-11 07:52:21.855771: E external/local_xla/xla/stream_executor/gpu/gpu_cudamallocasync_allocator.cc:66] 2359296, 18\n",
      "2024-06-11 07:52:21.855777: E external/local_xla/xla/stream_executor/gpu/gpu_cudamallocasync_allocator.cc:66] 2457600, 6\n",
      "2024-06-11 07:52:21.855784: E external/local_xla/xla/stream_executor/gpu/gpu_cudamallocasync_allocator.cc:66] 4194304, 15\n",
      "2024-06-11 07:52:21.855790: E external/local_xla/xla/stream_executor/gpu/gpu_cudamallocasync_allocator.cc:66] 8388608, 3\n",
      "2024-06-11 07:52:21.855796: E external/local_xla/xla/stream_executor/gpu/gpu_cudamallocasync_allocator.cc:66] 9437184, 9\n",
      "2024-06-11 07:52:21.855814: E external/local_xla/xla/stream_executor/gpu/gpu_cudamallocasync_allocator.cc:66] 45158400, 4\n",
      "2024-06-11 07:52:21.855818: E external/local_xla/xla/stream_executor/gpu/gpu_cudamallocasync_allocator.cc:66] 90316800, 1\n",
      "2024-06-11 07:52:21.855822: E external/local_xla/xla/stream_executor/gpu/gpu_cudamallocasync_allocator.cc:66] 91125000, 1\n",
      "2024-06-11 07:52:21.855826: E external/local_xla/xla/stream_executor/gpu/gpu_cudamallocasync_allocator.cc:66] 120422400, 8\n",
      "2024-06-11 07:52:21.855829: E external/local_xla/xla/stream_executor/gpu/gpu_cudamallocasync_allocator.cc:66] 128450560, 6\n",
      "2024-06-11 07:52:21.855833: E external/local_xla/xla/stream_executor/gpu/gpu_cudamallocasync_allocator.cc:66] 180633600, 4\n",
      "2024-06-11 07:52:21.855837: E external/local_xla/xla/stream_executor/gpu/gpu_cudamallocasync_allocator.cc:66] 205520896, 3\n",
      "2024-06-11 07:52:21.855841: E external/local_xla/xla/stream_executor/gpu/gpu_cudamallocasync_allocator.cc:66] 240844800, 7\n",
      "2024-06-11 07:52:21.855844: E external/local_xla/xla/stream_executor/gpu/gpu_cudamallocasync_allocator.cc:66] 270950400, 12\n",
      "2024-06-11 07:52:21.855848: E external/local_xla/xla/stream_executor/gpu/gpu_cudamallocasync_allocator.cc:66] 722534400, 2\n",
      "2024-06-11 07:52:21.855851: E external/local_xla/xla/stream_executor/gpu/gpu_cudamallocasync_allocator.cc:66] 735494400, 1\n",
      "2024-06-11 07:52:21.855857: E external/local_xla/xla/stream_executor/gpu/gpu_cudamallocasync_allocator.cc:97] CU_MEMPOOL_ATTR_RESERVED_MEM_CURRENT: 11173625856\n",
      "2024-06-11 07:52:21.855862: E external/local_xla/xla/stream_executor/gpu/gpu_cudamallocasync_allocator.cc:99] CU_MEMPOOL_ATTR_USED_MEM_CURRENT: 10975064201\n",
      "2024-06-11 07:52:21.855865: E external/local_xla/xla/stream_executor/gpu/gpu_cudamallocasync_allocator.cc:100] CU_MEMPOOL_ATTR_RESERVED_MEM_HIGH: 11341398016\n",
      "2024-06-11 07:52:21.855869: E external/local_xla/xla/stream_executor/gpu/gpu_cudamallocasync_allocator.cc:101] CU_MEMPOOL_ATTR_USED_MEM_HIGH: 11181894625\n",
      "2024-06-11 07:52:21.855901: W tensorflow/core/framework/op_kernel.cc:1839] OP_REQUIRES failed at fused_batch_norm_op.cc:1565 : RESOURCE_EXHAUSTED: OOM when allocating tensor with shape[150,144,56,56] and type float on /job:localhost/replica:0/task:0/device:GPU:0 by allocator gpu_async_0\n",
      "2024-06-11 07:52:21.855950: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 14193180684671067366\n",
      "2024-06-11 07:52:21.855965: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 3745759039294146572\n",
      "2024-06-11 07:52:21.855975: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 6414131225964677410\n",
      "2024-06-11 07:52:21.855985: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 9282331468825315364\n",
      "2024-06-11 07:52:21.855994: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 17779155178607095608\n",
      "2024-06-11 07:52:21.856005: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 10526712046949757980\n",
      "2024-06-11 07:52:21.856045: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 4963139675160378351\n",
      "2024-06-11 07:52:21.856078: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 10514427531560354845\n",
      "2024-06-11 07:52:21.856103: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 4025622587744388885\n",
      "2024-06-11 07:52:21.856131: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 9657826811866359943\n",
      "2024-06-11 07:52:21.856142: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 717668288838585165\n",
      "2024-06-11 07:52:21.856152: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 8291110681979611855\n",
      "2024-06-11 07:52:21.856163: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 12002692529494707989\n",
      "2024-06-11 07:52:21.856171: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 7993245699241636629\n",
      "2024-06-11 07:52:21.856181: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 12647817480979387719\n",
      "2024-06-11 07:52:21.856190: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 9875057590792624103\n",
      "2024-06-11 07:52:21.856200: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 8470366682711005347\n",
      "2024-06-11 07:52:21.856211: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 106593523138091881\n",
      "2024-06-11 07:52:21.856220: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 17517541487788383731\n",
      "2024-06-11 07:52:21.856230: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 4346427561983794845\n",
      "2024-06-11 07:52:21.856265: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 5302582222916547518\n",
      "2024-06-11 07:52:21.856278: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 18178316985059284902\n",
      "2024-06-11 07:52:21.856287: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 1267220115457177422\n",
      "2024-06-11 07:52:21.856306: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 6772426834693404678\n",
      "2024-06-11 07:52:21.856315: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 7389018421398133330\n",
      "2024-06-11 07:52:21.856323: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 5546849314300819704\n",
      "2024-06-11 07:52:21.856332: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 7074326720381794648\n"
     ]
    },
    {
     "ename": "ResourceExhaustedError",
     "evalue": "Graph execution error:\n\nDetected at node efficient_net_b1_ranking_model/efficientnetb1/block2c_bn/FusedBatchNormV3 defined at (most recent call last):\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/runpy.py\", line 196, in _run_module_as_main\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/runpy.py\", line 86, in _run_code\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/ipykernel_launcher.py\", line 17, in <module>\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/traitlets/config/application.py\", line 992, in launch_instance\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/ipykernel/kernelapp.py\", line 701, in start\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/tornado/platform/asyncio.py\", line 195, in start\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/asyncio/base_events.py\", line 603, in run_forever\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/asyncio/base_events.py\", line 1909, in _run_once\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/asyncio/events.py\", line 80, in _run\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/ipykernel/kernelbase.py\", line 534, in dispatch_queue\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/ipykernel/kernelbase.py\", line 523, in process_one\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/ipykernel/kernelbase.py\", line 429, in dispatch_shell\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/ipykernel/kernelbase.py\", line 767, in execute_request\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/ipykernel/ipkernel.py\", line 429, in do_execute\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/ipykernel/zmqshell.py\", line 549, in run_cell\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/IPython/core/interactiveshell.py\", line 3051, in run_cell\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/IPython/core/interactiveshell.py\", line 3106, in _run_cell\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/IPython/core/async_helpers.py\", line 129, in _pseudo_sync_runner\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/IPython/core/interactiveshell.py\", line 3311, in run_cell_async\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/IPython/core/interactiveshell.py\", line 3493, in run_ast_nodes\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/IPython/core/interactiveshell.py\", line 3553, in run_code\n\n  File \"/tmp/ipykernel_262602/1296180263.py\", line 1, in <module>\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/IPython/core/interactiveshell.py\", line 2517, in run_cell_magic\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/IPython/core/magics/execution.py\", line 1340, in time\n\n  File \"<timed exec>\", line 1, in <module>\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/keras/src/utils/traceback_utils.py\", line 65, in error_handler\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/keras/src/engine/training.py\", line 1807, in fit\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/keras/src/engine/training.py\", line 1401, in train_function\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/keras/src/engine/training.py\", line 1384, in step_function\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/keras/src/engine/training.py\", line 1373, in run_step\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/tensorflow_recommenders/models/base.py\", line 68, in train_step\n\n  File \"/home/gorbuljaal/diploma/models/efficientnetb1.py\", line 81, in compute_loss\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/keras/src/utils/traceback_utils.py\", line 65, in error_handler\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/keras/src/engine/training.py\", line 590, in __call__\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/keras/src/utils/traceback_utils.py\", line 65, in error_handler\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/keras/src/engine/base_layer.py\", line 1149, in __call__\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/keras/src/utils/traceback_utils.py\", line 96, in error_handler\n\n  File \"/home/gorbuljaal/diploma/models/efficientnetb0.py\", line 71, in call\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/keras/src/utils/traceback_utils.py\", line 65, in error_handler\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/keras/src/engine/training.py\", line 590, in __call__\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/keras/src/utils/traceback_utils.py\", line 65, in error_handler\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/keras/src/engine/base_layer.py\", line 1149, in __call__\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/keras/src/utils/traceback_utils.py\", line 96, in error_handler\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/keras/src/engine/functional.py\", line 515, in call\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/keras/src/engine/functional.py\", line 672, in _run_internal_graph\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/keras/src/utils/traceback_utils.py\", line 65, in error_handler\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/keras/src/engine/base_layer.py\", line 1149, in __call__\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/keras/src/utils/traceback_utils.py\", line 96, in error_handler\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/keras/src/layers/normalization/batch_normalization.py\", line 597, in call\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/keras/src/layers/normalization/batch_normalization.py\", line 990, in _fused_batch_norm\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/keras/src/utils/control_flow_util.py\", line 108, in smart_cond\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/keras/src/layers/normalization/batch_normalization.py\", line 979, in _fused_batch_norm_inference\n\nDetected at node efficient_net_b1_ranking_model/efficientnetb1/block2c_bn/FusedBatchNormV3 defined at (most recent call last):\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/runpy.py\", line 196, in _run_module_as_main\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/runpy.py\", line 86, in _run_code\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/ipykernel_launcher.py\", line 17, in <module>\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/traitlets/config/application.py\", line 992, in launch_instance\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/ipykernel/kernelapp.py\", line 701, in start\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/tornado/platform/asyncio.py\", line 195, in start\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/asyncio/base_events.py\", line 603, in run_forever\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/asyncio/base_events.py\", line 1909, in _run_once\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/asyncio/events.py\", line 80, in _run\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/ipykernel/kernelbase.py\", line 534, in dispatch_queue\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/ipykernel/kernelbase.py\", line 523, in process_one\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/ipykernel/kernelbase.py\", line 429, in dispatch_shell\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/ipykernel/kernelbase.py\", line 767, in execute_request\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/ipykernel/ipkernel.py\", line 429, in do_execute\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/ipykernel/zmqshell.py\", line 549, in run_cell\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/IPython/core/interactiveshell.py\", line 3051, in run_cell\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/IPython/core/interactiveshell.py\", line 3106, in _run_cell\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/IPython/core/async_helpers.py\", line 129, in _pseudo_sync_runner\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/IPython/core/interactiveshell.py\", line 3311, in run_cell_async\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/IPython/core/interactiveshell.py\", line 3493, in run_ast_nodes\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/IPython/core/interactiveshell.py\", line 3553, in run_code\n\n  File \"/tmp/ipykernel_262602/1296180263.py\", line 1, in <module>\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/IPython/core/interactiveshell.py\", line 2517, in run_cell_magic\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/IPython/core/magics/execution.py\", line 1340, in time\n\n  File \"<timed exec>\", line 1, in <module>\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/keras/src/utils/traceback_utils.py\", line 65, in error_handler\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/keras/src/engine/training.py\", line 1807, in fit\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/keras/src/engine/training.py\", line 1401, in train_function\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/keras/src/engine/training.py\", line 1384, in step_function\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/keras/src/engine/training.py\", line 1373, in run_step\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/tensorflow_recommenders/models/base.py\", line 68, in train_step\n\n  File \"/home/gorbuljaal/diploma/models/efficientnetb1.py\", line 81, in compute_loss\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/keras/src/utils/traceback_utils.py\", line 65, in error_handler\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/keras/src/engine/training.py\", line 590, in __call__\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/keras/src/utils/traceback_utils.py\", line 65, in error_handler\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/keras/src/engine/base_layer.py\", line 1149, in __call__\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/keras/src/utils/traceback_utils.py\", line 96, in error_handler\n\n  File \"/home/gorbuljaal/diploma/models/efficientnetb0.py\", line 71, in call\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/keras/src/utils/traceback_utils.py\", line 65, in error_handler\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/keras/src/engine/training.py\", line 590, in __call__\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/keras/src/utils/traceback_utils.py\", line 65, in error_handler\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/keras/src/engine/base_layer.py\", line 1149, in __call__\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/keras/src/utils/traceback_utils.py\", line 96, in error_handler\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/keras/src/engine/functional.py\", line 515, in call\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/keras/src/engine/functional.py\", line 672, in _run_internal_graph\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/keras/src/utils/traceback_utils.py\", line 65, in error_handler\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/keras/src/engine/base_layer.py\", line 1149, in __call__\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/keras/src/utils/traceback_utils.py\", line 96, in error_handler\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/keras/src/layers/normalization/batch_normalization.py\", line 597, in call\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/keras/src/layers/normalization/batch_normalization.py\", line 990, in _fused_batch_norm\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/keras/src/utils/control_flow_util.py\", line 108, in smart_cond\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/keras/src/layers/normalization/batch_normalization.py\", line 979, in _fused_batch_norm_inference\n\n2 root error(s) found.\n  (0) RESOURCE_EXHAUSTED:  OOM when allocating tensor with shape[150,144,56,56] and type float on /job:localhost/replica:0/task:0/device:GPU:0 by allocator gpu_async_0\n\t [[{{node efficient_net_b1_ranking_model/efficientnetb1/block2c_bn/FusedBatchNormV3}}]]\nHint: If you want to see a list of allocated tensors when OOM happens, add report_tensor_allocations_upon_oom to RunOptions for current allocation info. This isn't available when running in Eager mode.\n\n\t [[ranking_2/broadcast_weights_2/assert_broadcastable/AssertGuard/pivot_f/_61/_189]]\nHint: If you want to see a list of allocated tensors when OOM happens, add report_tensor_allocations_upon_oom to RunOptions for current allocation info. This isn't available when running in Eager mode.\n\n  (1) RESOURCE_EXHAUSTED:  OOM when allocating tensor with shape[150,144,56,56] and type float on /job:localhost/replica:0/task:0/device:GPU:0 by allocator gpu_async_0\n\t [[{{node efficient_net_b1_ranking_model/efficientnetb1/block2c_bn/FusedBatchNormV3}}]]\nHint: If you want to see a list of allocated tensors when OOM happens, add report_tensor_allocations_upon_oom to RunOptions for current allocation info. This isn't available when running in Eager mode.\n\n0 successful operations.\n0 derived errors ignored. [Op:__inference_train_function_522051]",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mResourceExhaustedError\u001b[0m                    Traceback (most recent call last)",
      "File \u001b[0;32m<timed exec>:1\u001b[0m\n",
      "File \u001b[0;32m~/.conda/envs/diploma/lib/python3.10/site-packages/keras/src/utils/traceback_utils.py:70\u001b[0m, in \u001b[0;36mfilter_traceback.<locals>.error_handler\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m     67\u001b[0m     filtered_tb \u001b[38;5;241m=\u001b[39m _process_traceback_frames(e\u001b[38;5;241m.\u001b[39m__traceback__)\n\u001b[1;32m     68\u001b[0m     \u001b[38;5;66;03m# To get the full stack trace, call:\u001b[39;00m\n\u001b[1;32m     69\u001b[0m     \u001b[38;5;66;03m# `tf.debugging.disable_traceback_filtering()`\u001b[39;00m\n\u001b[0;32m---> 70\u001b[0m     \u001b[38;5;28;01mraise\u001b[39;00m e\u001b[38;5;241m.\u001b[39mwith_traceback(filtered_tb) \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[1;32m     71\u001b[0m \u001b[38;5;28;01mfinally\u001b[39;00m:\n\u001b[1;32m     72\u001b[0m     \u001b[38;5;28;01mdel\u001b[39;00m filtered_tb\n",
      "File \u001b[0;32m~/.conda/envs/diploma/lib/python3.10/site-packages/tensorflow/python/eager/execute.py:53\u001b[0m, in \u001b[0;36mquick_execute\u001b[0;34m(op_name, num_outputs, inputs, attrs, ctx, name)\u001b[0m\n\u001b[1;32m     51\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m     52\u001b[0m   ctx\u001b[38;5;241m.\u001b[39mensure_initialized()\n\u001b[0;32m---> 53\u001b[0m   tensors \u001b[38;5;241m=\u001b[39m pywrap_tfe\u001b[38;5;241m.\u001b[39mTFE_Py_Execute(ctx\u001b[38;5;241m.\u001b[39m_handle, device_name, op_name,\n\u001b[1;32m     54\u001b[0m                                       inputs, attrs, num_outputs)\n\u001b[1;32m     55\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m core\u001b[38;5;241m.\u001b[39m_NotOkStatusException \u001b[38;5;28;01mas\u001b[39;00m e:\n\u001b[1;32m     56\u001b[0m   \u001b[38;5;28;01mif\u001b[39;00m name \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n",
      "\u001b[0;31mResourceExhaustedError\u001b[0m: Graph execution error:\n\nDetected at node efficient_net_b1_ranking_model/efficientnetb1/block2c_bn/FusedBatchNormV3 defined at (most recent call last):\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/runpy.py\", line 196, in _run_module_as_main\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/runpy.py\", line 86, in _run_code\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/ipykernel_launcher.py\", line 17, in <module>\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/traitlets/config/application.py\", line 992, in launch_instance\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/ipykernel/kernelapp.py\", line 701, in start\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/tornado/platform/asyncio.py\", line 195, in start\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/asyncio/base_events.py\", line 603, in run_forever\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/asyncio/base_events.py\", line 1909, in _run_once\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/asyncio/events.py\", line 80, in _run\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/ipykernel/kernelbase.py\", line 534, in dispatch_queue\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/ipykernel/kernelbase.py\", line 523, in process_one\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/ipykernel/kernelbase.py\", line 429, in dispatch_shell\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/ipykernel/kernelbase.py\", line 767, in execute_request\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/ipykernel/ipkernel.py\", line 429, in do_execute\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/ipykernel/zmqshell.py\", line 549, in run_cell\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/IPython/core/interactiveshell.py\", line 3051, in run_cell\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/IPython/core/interactiveshell.py\", line 3106, in _run_cell\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/IPython/core/async_helpers.py\", line 129, in _pseudo_sync_runner\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/IPython/core/interactiveshell.py\", line 3311, in run_cell_async\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/IPython/core/interactiveshell.py\", line 3493, in run_ast_nodes\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/IPython/core/interactiveshell.py\", line 3553, in run_code\n\n  File \"/tmp/ipykernel_262602/1296180263.py\", line 1, in <module>\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/IPython/core/interactiveshell.py\", line 2517, in run_cell_magic\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/IPython/core/magics/execution.py\", line 1340, in time\n\n  File \"<timed exec>\", line 1, in <module>\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/keras/src/utils/traceback_utils.py\", line 65, in error_handler\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/keras/src/engine/training.py\", line 1807, in fit\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/keras/src/engine/training.py\", line 1401, in train_function\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/keras/src/engine/training.py\", line 1384, in step_function\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/keras/src/engine/training.py\", line 1373, in run_step\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/tensorflow_recommenders/models/base.py\", line 68, in train_step\n\n  File \"/home/gorbuljaal/diploma/models/efficientnetb1.py\", line 81, in compute_loss\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/keras/src/utils/traceback_utils.py\", line 65, in error_handler\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/keras/src/engine/training.py\", line 590, in __call__\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/keras/src/utils/traceback_utils.py\", line 65, in error_handler\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/keras/src/engine/base_layer.py\", line 1149, in __call__\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/keras/src/utils/traceback_utils.py\", line 96, in error_handler\n\n  File \"/home/gorbuljaal/diploma/models/efficientnetb0.py\", line 71, in call\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/keras/src/utils/traceback_utils.py\", line 65, in error_handler\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/keras/src/engine/training.py\", line 590, in __call__\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/keras/src/utils/traceback_utils.py\", line 65, in error_handler\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/keras/src/engine/base_layer.py\", line 1149, in __call__\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/keras/src/utils/traceback_utils.py\", line 96, in error_handler\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/keras/src/engine/functional.py\", line 515, in call\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/keras/src/engine/functional.py\", line 672, in _run_internal_graph\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/keras/src/utils/traceback_utils.py\", line 65, in error_handler\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/keras/src/engine/base_layer.py\", line 1149, in __call__\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/keras/src/utils/traceback_utils.py\", line 96, in error_handler\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/keras/src/layers/normalization/batch_normalization.py\", line 597, in call\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/keras/src/layers/normalization/batch_normalization.py\", line 990, in _fused_batch_norm\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/keras/src/utils/control_flow_util.py\", line 108, in smart_cond\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/keras/src/layers/normalization/batch_normalization.py\", line 979, in _fused_batch_norm_inference\n\nDetected at node efficient_net_b1_ranking_model/efficientnetb1/block2c_bn/FusedBatchNormV3 defined at (most recent call last):\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/runpy.py\", line 196, in _run_module_as_main\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/runpy.py\", line 86, in _run_code\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/ipykernel_launcher.py\", line 17, in <module>\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/traitlets/config/application.py\", line 992, in launch_instance\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/ipykernel/kernelapp.py\", line 701, in start\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/tornado/platform/asyncio.py\", line 195, in start\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/asyncio/base_events.py\", line 603, in run_forever\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/asyncio/base_events.py\", line 1909, in _run_once\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/asyncio/events.py\", line 80, in _run\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/ipykernel/kernelbase.py\", line 534, in dispatch_queue\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/ipykernel/kernelbase.py\", line 523, in process_one\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/ipykernel/kernelbase.py\", line 429, in dispatch_shell\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/ipykernel/kernelbase.py\", line 767, in execute_request\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/ipykernel/ipkernel.py\", line 429, in do_execute\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/ipykernel/zmqshell.py\", line 549, in run_cell\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/IPython/core/interactiveshell.py\", line 3051, in run_cell\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/IPython/core/interactiveshell.py\", line 3106, in _run_cell\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/IPython/core/async_helpers.py\", line 129, in _pseudo_sync_runner\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/IPython/core/interactiveshell.py\", line 3311, in run_cell_async\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/IPython/core/interactiveshell.py\", line 3493, in run_ast_nodes\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/IPython/core/interactiveshell.py\", line 3553, in run_code\n\n  File \"/tmp/ipykernel_262602/1296180263.py\", line 1, in <module>\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/IPython/core/interactiveshell.py\", line 2517, in run_cell_magic\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/IPython/core/magics/execution.py\", line 1340, in time\n\n  File \"<timed exec>\", line 1, in <module>\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/keras/src/utils/traceback_utils.py\", line 65, in error_handler\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/keras/src/engine/training.py\", line 1807, in fit\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/keras/src/engine/training.py\", line 1401, in train_function\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/keras/src/engine/training.py\", line 1384, in step_function\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/keras/src/engine/training.py\", line 1373, in run_step\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/tensorflow_recommenders/models/base.py\", line 68, in train_step\n\n  File \"/home/gorbuljaal/diploma/models/efficientnetb1.py\", line 81, in compute_loss\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/keras/src/utils/traceback_utils.py\", line 65, in error_handler\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/keras/src/engine/training.py\", line 590, in __call__\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/keras/src/utils/traceback_utils.py\", line 65, in error_handler\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/keras/src/engine/base_layer.py\", line 1149, in __call__\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/keras/src/utils/traceback_utils.py\", line 96, in error_handler\n\n  File \"/home/gorbuljaal/diploma/models/efficientnetb0.py\", line 71, in call\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/keras/src/utils/traceback_utils.py\", line 65, in error_handler\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/keras/src/engine/training.py\", line 590, in __call__\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/keras/src/utils/traceback_utils.py\", line 65, in error_handler\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/keras/src/engine/base_layer.py\", line 1149, in __call__\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/keras/src/utils/traceback_utils.py\", line 96, in error_handler\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/keras/src/engine/functional.py\", line 515, in call\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/keras/src/engine/functional.py\", line 672, in _run_internal_graph\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/keras/src/utils/traceback_utils.py\", line 65, in error_handler\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/keras/src/engine/base_layer.py\", line 1149, in __call__\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/keras/src/utils/traceback_utils.py\", line 96, in error_handler\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/keras/src/layers/normalization/batch_normalization.py\", line 597, in call\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/keras/src/layers/normalization/batch_normalization.py\", line 990, in _fused_batch_norm\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/keras/src/utils/control_flow_util.py\", line 108, in smart_cond\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/keras/src/layers/normalization/batch_normalization.py\", line 979, in _fused_batch_norm_inference\n\n2 root error(s) found.\n  (0) RESOURCE_EXHAUSTED:  OOM when allocating tensor with shape[150,144,56,56] and type float on /job:localhost/replica:0/task:0/device:GPU:0 by allocator gpu_async_0\n\t [[{{node efficient_net_b1_ranking_model/efficientnetb1/block2c_bn/FusedBatchNormV3}}]]\nHint: If you want to see a list of allocated tensors when OOM happens, add report_tensor_allocations_upon_oom to RunOptions for current allocation info. This isn't available when running in Eager mode.\n\n\t [[ranking_2/broadcast_weights_2/assert_broadcastable/AssertGuard/pivot_f/_61/_189]]\nHint: If you want to see a list of allocated tensors when OOM happens, add report_tensor_allocations_upon_oom to RunOptions for current allocation info. This isn't available when running in Eager mode.\n\n  (1) RESOURCE_EXHAUSTED:  OOM when allocating tensor with shape[150,144,56,56] and type float on /job:localhost/replica:0/task:0/device:GPU:0 by allocator gpu_async_0\n\t [[{{node efficient_net_b1_ranking_model/efficientnetb1/block2c_bn/FusedBatchNormV3}}]]\nHint: If you want to see a list of allocated tensors when OOM happens, add report_tensor_allocations_upon_oom to RunOptions for current allocation info. This isn't available when running in Eager mode.\n\n0 successful operations.\n0 derived errors ignored. [Op:__inference_train_function_522051]"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "\n",
    "history = model.fit(train_dataset, epochs=EPOCHS)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "id": "788168f0-d9e5-4594-bd57-19aa28216b3d",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 16 µs, sys: 2 µs, total: 18 µs\n",
      "Wall time: 36 µs\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "\n",
    "show_on_test_batch = True"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "id": "2aff2796-8725-4e5d-aa69-dcef6f78a60a",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<_TakeDataset element_spec=(TensorSpec(shape=(None, 5, 224, 224, 3), dtype=tf.float32, name=None), TensorSpec(shape=(None, 5), dtype=tf.int32, name=None))>\n",
      "tf.Tensor([0.4926491  0.49213213 0.4886324  0.48524582 0.49300435], shape=(5,), dtype=float32) tf.Tensor([3 4 2 1 0], shape=(5,), dtype=int32) tf.Tensor([4 3 2 1 5], shape=(5,), dtype=int32)\n",
      "tf.Tensor([0.5061717  0.498957   0.49702096 0.49147776 0.47797263], shape=(5,), dtype=float32) tf.Tensor([3 2 4 0 1], shape=(5,), dtype=int32) tf.Tensor([5 4 3 2 1], shape=(5,), dtype=int32)\n",
      "tf.Tensor([0.48290005 0.48790368 0.4818572  0.4858101  0.490746  ], shape=(5,), dtype=float32) tf.Tensor([3 0 4 1 2], shape=(5,), dtype=int32) tf.Tensor([2 4 1 3 5], shape=(5,), dtype=int32)\n",
      "CPU times: user 12.9 s, sys: 761 ms, total: 13.7 s\n",
      "Wall time: 6.51 s\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "\n",
    "if show_on_test_batch:\n",
    "    batch = train_dataset.take(1)\n",
    "    print(batch)\n",
    "    for images, labels in batch:\n",
    "        print(model(images)[0], labels[0], predict_rank(model, images)[0])\n",
    "        print(model(images)[1], labels[1], predict_rank(model, images)[1])\n",
    "        print(model(images)[2], labels[2], predict_rank(model, images)[2])\n",
    "        break"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "138964c3-539f-42c1-a043-9275b9a24b53",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/1 [==============================] - 5s 5s/step - ndcg_metric: 0.8559 - mrr_metric: 1.0000 - opa_metric: 0.7440 - loss: 4.2762 - regularization_loss: 0.0000e+00 - total_loss: 4.2762\n",
      "Test metrics: {'ndcg_metric': 0.85591721534729, 'mrr_metric': 1.0, 'opa_metric': 0.7440000176429749, 'loss': 4.276159763336182, 'regularization_loss': 0, 'total_loss': 4.276159763336182}\n"
     ]
    }
   ],
   "source": [
    "test_metrics = model.evaluate(test_images, test_labels, return_dict=True)\n",
    "print(f\"Test metrics: {test_metrics}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "d91d2e8f-d370-4180-a52a-b11c9b9dd1d7",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/1 [==============================] - 1s 506ms/step - ndcg_metric: 0.8905 - mrr_metric: 1.0000 - opa_metric: 0.7960 - loss: 3.8017 - regularization_loss: 0.0000e+00 - total_loss: 3.8017\n",
      "Val metrics: {'ndcg_metric': 0.85591721534729, 'mrr_metric': 1.0, 'opa_metric': 0.7440000176429749, 'loss': 4.276159763336182, 'regularization_loss': 0, 'total_loss': 4.276159763336182}\n"
     ]
    }
   ],
   "source": [
    "val_metrics = model.evaluate(val_images, val_labels, return_dict=True)\n",
    "print(f\"Val metrics: {test_metrics}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "id": "af2730a1-c24e-4f80-b359-aed6d6b26cd1",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(62720, 512), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7fc71a32a2f0>, 140493688022208), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(62720, 512), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7fc71a32a2f0>, 140493688022208), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(512,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7fc71a329120>, 140494692101184), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(512,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7fc71a329120>, 140494692101184), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(512, 256), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7fc71a328af0>, 140504553008992), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(512, 256), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7fc71a328af0>, 140504553008992), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(256,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7fc71a4a6860>, 140504552998112), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(256,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7fc71a4a6860>, 140504552998112), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(256, 128), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7fc71a4a4ac0>, 140504552998272), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(256, 128), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7fc71a4a4ac0>, 140504552998272), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(128,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7fc71a4a6080>, 140504552997872), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(128,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7fc71a4a6080>, 140504552997872), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(128, 64), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7fc71a3dcd60>, 140493687889536), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(128, 64), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7fc71a3dcd60>, 140493687889536), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(64,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7fc71a3de710>, 140493687889616), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(64,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7fc71a3de710>, 140493687889616), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(64, 1), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7fc71a3dc5b0>, 140493688019728), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(64, 1), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7fc71a3dc5b0>, 140493688019728), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(1,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7fc71a3dd780>, 140493688021168), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(1,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7fc71a3dd780>, 140493688021168), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(62720, 512), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7fc71a32a2f0>, 140493688022208), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(62720, 512), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7fc71a32a2f0>, 140493688022208), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(512,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7fc71a329120>, 140494692101184), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(512,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7fc71a329120>, 140494692101184), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(512, 256), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7fc71a328af0>, 140504553008992), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(512, 256), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7fc71a328af0>, 140504553008992), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(256,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7fc71a4a6860>, 140504552998112), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(256,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7fc71a4a6860>, 140504552998112), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(256, 128), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7fc71a4a4ac0>, 140504552998272), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(256, 128), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7fc71a4a4ac0>, 140504552998272), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(128,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7fc71a4a6080>, 140504552997872), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(128,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7fc71a4a6080>, 140504552997872), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(128, 64), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7fc71a3dcd60>, 140493687889536), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(128, 64), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7fc71a3dcd60>, 140493687889536), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(64,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7fc71a3de710>, 140493687889616), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(64,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7fc71a3de710>, 140493687889616), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(64, 1), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7fc71a3dc5b0>, 140493688019728), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(64, 1), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7fc71a3dc5b0>, 140493688019728), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(1,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7fc71a3dd780>, 140493688021168), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(1,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7fc71a3dd780>, 140493688021168), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Assets written to: saved_models/EfficientNetB1RankingModel_20240611_075229/assets\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Assets written to: saved_models/EfficientNetB1RankingModel_20240611_075229/assets\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model saved in saved_models/EfficientNetB1RankingModel_20240611_075229 as EfficientNetB1RankingModel_20240611_075229\n",
      "CPU times: user 46.9 s, sys: 4.29 s, total: 51.2 s\n",
      "Wall time: 41 s\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "\n",
    "if SAVE_MODEL:\n",
    "    save_model_with_timestamp_and_lr(model, lr)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c9db5b6c-d1c5-4860-a6b4-4471a8817996",
   "metadata": {},
   "source": [
    "### Evaluating"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a155cdef-c724-459d-827c-379091a57d79",
   "metadata": {},
   "outputs": [],
   "source": [
    "loaded_model = tf.keras.models.load_model('tesla_saved_models/EfficientNetB1RankingModel_20240612_162253_unfreezed_0.0001', \n",
    "                                          custom_objects={'ListMLELoss': tfr.keras.losses.ListMLELoss()})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ce88df78-70b1-4422-a506-291498bd3bb1",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "%%time\n",
    "\n",
    "metrics = compute_dataset_metrics(loaded_model, train_dataset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b83b1b4b-efbc-44f8-b6b5-5b487e8813ec",
   "metadata": {},
   "outputs": [],
   "source": [
    "%%time\n",
    "\n",
    "print(metrics)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "66e8d97a-1df0-4ded-8d7f-7193ad2b1506",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "%%time\n",
    "\n",
    "metrics = compute_dataset_metrics(loaded_model, val_dataset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "db1bbd52-f3a6-4058-95e0-ac3e7533e575",
   "metadata": {},
   "outputs": [],
   "source": [
    "%%time\n",
    "\n",
    "print(metrics)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c431c899-7d5b-4c0a-a9a4-945201eac666",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "%%time\n",
    "\n",
    "metrics = compute_dataset_metrics(loaded_model, test_dataset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "eb175c31-b3ef-4acf-8fc2-3c5a8caf34d4",
   "metadata": {},
   "outputs": [],
   "source": [
    "%%time\n",
    "\n",
    "print(metrics)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4fce40c6-c79c-4c45-88ab-f423d1f06b53",
   "metadata": {},
   "outputs": [],
   "source": [
    "%%time\n",
    "\n",
    "show_on_test_batch = True"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e4071a3e-3e3a-4809-93d0-35bdd2c43de5",
   "metadata": {},
   "outputs": [],
   "source": [
    "%%time\n",
    "\n",
    "if show_on_test_batch:\n",
    "    batch = test_dataset.take(1)\n",
    "    print(batch)\n",
    "    for images, labels in batch:\n",
    "        print(model(images)[0], labels[0], predict_rank(model, images)[0])\n",
    "        print(model(images)[1], labels[1], predict_rank(model, images)[1])\n",
    "        print(model(images)[2], labels[2], predict_rank(model, images)[2])\n",
    "        break"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "985ef6e8-98a7-41cd-88e6-c22f31cc10b4",
   "metadata": {},
   "outputs": [],
   "source": [
    "%%time\n",
    "\n",
    "plot_metrics(metrics)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "88c3a5d6-5b64-40fd-8276-806800790b41",
   "metadata": {},
   "source": [
    "## EfficientNet-B2"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2015e0c4-39dc-4880-bd26-9dc502f79c2b",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true
   },
   "source": [
    "### Training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "id": "6a57eb08-1d46-4d01-960f-660340281f69",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 2.2 s, sys: 28.4 ms, total: 2.23 s\n",
      "Wall time: 2.15 s\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "\n",
    "loss = tfr.keras.losses.ListMLELoss()\n",
    "model = EfficientNetB2RankingModel(loss, trainable=True)\n",
    "lr = 1e-5\n",
    "optimizer = tf.keras.optimizers.Adam(learning_rate=lr, clipnorm=1.0)\n",
    "model.compile(optimizer=optimizer)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "id": "1736a81b-0814-48f8-bede-b10c85bbfc5e",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 3.18 ms, sys: 132 μs, total: 3.31 ms\n",
      "Wall time: 1.54 ms\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "\n",
    "batch = train_dataset.take(1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "id": "f54da18e-0b31-4d42-bb48-d845e8e9f92c",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/1 [==============================] - 10s 10s/step\n",
      "CPU times: user 9.63 s, sys: 305 ms, total: 9.94 s\n",
      "Wall time: 9.62 s\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2024-06-12 18:03:22.571923: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 12236039780566329570\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "array([[-0.7773818 , -0.77809066, -0.77971023, -0.78285104, -0.77724135],\n",
       "       [-0.7755964 , -0.7778571 , -0.7767492 , -0.7772635 , -0.78017116],\n",
       "       [-0.77766955, -0.77654576, -0.7782363 , -0.7759888 , -0.77649605],\n",
       "       [-0.7775134 , -0.77745754, -0.7772269 , -0.7780468 , -0.7782681 ],\n",
       "       [-0.7809101 , -0.7933723 , -0.7817878 , -0.7795198 , -0.8162883 ],\n",
       "       [-0.7819762 , -0.77871746, -0.7816886 , -0.7777947 , -0.78164566],\n",
       "       [-0.78096753, -0.7814978 , -0.7837143 , -0.781403  , -0.78239465],\n",
       "       [-0.79599077, -0.7870174 , -0.7841808 , -0.79142326, -0.78367966],\n",
       "       [-0.77648056, -0.7791383 , -0.77619183, -0.77838004, -0.77735186],\n",
       "       [-0.77505696, -0.78869975, -0.7835237 , -0.77932763, -0.7916694 ],\n",
       "       [-0.7770834 , -0.7787132 , -0.77708083, -0.7782062 , -0.7764028 ],\n",
       "       [-0.77771664, -0.77933997, -0.7859715 , -0.78425574, -0.79995406],\n",
       "       [-0.78077865, -0.780666  , -0.7781638 , -0.78191835, -0.7773582 ],\n",
       "       [-0.7774368 , -0.7802912 , -0.77740794, -0.7758064 , -0.7790513 ],\n",
       "       [-0.7833033 , -0.7789551 , -0.77711576, -0.77784336, -0.7776076 ],\n",
       "       [-0.777769  , -0.7762181 , -0.7774154 , -0.7761536 , -0.77716964],\n",
       "       [-0.7838948 , -0.7890384 , -0.7817035 , -0.78195244, -0.7799636 ],\n",
       "       [-0.7791064 , -0.77793634, -0.7808076 , -0.7782302 , -0.77782613],\n",
       "       [-0.7961519 , -0.77751505, -0.7782936 , -0.77757215, -0.77683896],\n",
       "       [-0.77774817, -0.7788057 , -0.7753139 , -0.77831316, -0.7762413 ],\n",
       "       [-0.7873942 , -0.7834159 , -0.7909962 , -0.7854275 , -0.7822611 ],\n",
       "       [-0.7757127 , -0.7790548 , -0.7799685 , -0.7781225 , -0.78087026],\n",
       "       [-0.7810495 , -0.7870135 , -0.7785851 , -0.78468174, -0.78180397],\n",
       "       [-0.7779414 , -0.7798981 , -0.77883846, -0.7746892 , -0.77744764],\n",
       "       [-0.77871454, -0.776567  , -0.7802477 , -0.7763076 , -0.7775276 ],\n",
       "       [-0.77867264, -0.7776665 , -0.7772653 , -0.7776439 , -0.7774075 ],\n",
       "       [-0.78114057, -0.78024065, -0.77841353, -0.77925736, -0.7800921 ],\n",
       "       [-0.7778436 , -0.7763322 , -0.7769538 , -0.7815331 , -0.7765378 ],\n",
       "       [-0.7822211 , -0.7867166 , -0.776551  , -0.78907394, -0.77470016],\n",
       "       [-0.77819276, -0.7818895 , -0.7820047 , -0.777077  , -0.77799505]],\n",
       "      dtype=float32)"
      ]
     },
     "execution_count": 32,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "%%time\n",
    "\n",
    "model.predict(batch)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "id": "e5874a0d-442e-4321-94a8-ad05f5d66819",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model: \"efficient_net_b2_ranking_model_1\"\n",
      "_________________________________________________________________\n",
      " Layer (type)                Output Shape              Param #   \n",
      "=================================================================\n",
      " efficientnetb2 (Functional  (None, 7, 7, 1408)        7768569   \n",
      " )                                                               \n",
      "                                                                 \n",
      " flatten_3 (Flatten)         multiple                  0         \n",
      "                                                                 \n",
      " sequential_6 (Sequential)   (None, 64)                35496896  \n",
      "                                                                 \n",
      " sequential_7 (Sequential)   (None, 1)                 65        \n",
      "                                                                 \n",
      " ranking_3 (Ranking)         multiple                  0 (unused)\n",
      "                                                                 \n",
      "=================================================================\n",
      "Total params: 43265530 (165.04 MB)\n",
      "Trainable params: 43197955 (164.79 MB)\n",
      "Non-trainable params: 67575 (263.97 KB)\n",
      "_________________________________________________________________\n",
      "CPU times: user 35.4 ms, sys: 110 μs, total: 35.5 ms\n",
      "Wall time: 33 ms\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "\n",
    "model.summary()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "id": "573199c1-66e0-4c2c-b8da-646eb701cf47",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1/10\n",
      "    125/Unknown - 513s 4s/step - ndcg_metric: 0.7585 - mrr_metric: 0.9317 - opa_metric: 0.6017 - loss: 4.6101 - regularization_loss: 0.0000e+00 - total_loss: 4.6101"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2024-06-12 18:11:58.857271: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 14887591350252389113\n",
      "2024-06-12 18:11:58.857335: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 8001501654279948869\n",
      "2024-06-12 18:11:58.857353: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 15504225495810607993\n",
      "2024-06-12 18:11:58.857367: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 13112338438219188855\n",
      "2024-06-12 18:11:58.857381: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 3988391146775465645\n",
      "2024-06-12 18:11:58.857396: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 17157022900605095531\n",
      "2024-06-12 18:11:58.857410: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 5445164720759591845\n",
      "2024-06-12 18:11:58.857424: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 11754082295435940963\n",
      "2024-06-12 18:11:58.857439: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 17130246446550832393\n",
      "2024-06-12 18:11:58.857454: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 11639950169263512289\n",
      "2024-06-12 18:11:58.857468: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 13144611441702936683\n",
      "2024-06-12 18:11:58.857481: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 1690513768386666883\n",
      "2024-06-12 18:11:58.857495: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 627857077779571273\n",
      "2024-06-12 18:11:58.857509: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 4876055975875715949\n",
      "2024-06-12 18:11:58.857522: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 2332606075049846619\n",
      "2024-06-12 18:11:58.857537: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 17172884051399208649\n",
      "2024-06-12 18:11:58.857551: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 14828588873369267875\n",
      "2024-06-12 18:11:58.857565: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 1350881258191434095\n",
      "2024-06-12 18:11:58.857578: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 4761415440412146013\n",
      "2024-06-12 18:11:58.857599: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 3565839061828501696\n",
      "2024-06-12 18:11:58.857613: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 11489168142061332420\n",
      "2024-06-12 18:11:58.857627: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 6643076985649662212\n",
      "2024-06-12 18:11:58.857640: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 5038098455500598408\n",
      "2024-06-12 18:11:58.857653: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 11457365190439710984\n",
      "2024-06-12 18:11:58.857666: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 11841714151021155192\n",
      "2024-06-12 18:11:58.857680: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 6534111393205400974\n",
      "2024-06-12 18:11:58.857693: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 8705840371694000060\n",
      "2024-06-12 18:11:58.857706: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 7293075784703751104\n",
      "2024-06-12 18:11:58.857720: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 3520130154898037408\n",
      "2024-06-12 18:11:58.857736: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 5758698750686698054\n",
      "2024-06-12 18:11:58.858053: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 15251377608867740056\n",
      "2024-06-12 18:11:58.858086: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 15144058799074703814\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "125/125 [==============================] - 517s 4s/step - ndcg_metric: 0.7585 - mrr_metric: 0.9317 - opa_metric: 0.6017 - loss: 4.6094 - regularization_loss: 0.0000e+00 - total_loss: 4.6094\n",
      "Epoch 2/10\n",
      "125/125 [==============================] - 476s 4s/step - ndcg_metric: 0.7792 - mrr_metric: 0.9455 - opa_metric: 0.6438 - loss: 4.4394 - regularization_loss: 0.0000e+00 - total_loss: 4.4394\n",
      "Epoch 3/10\n",
      "125/125 [==============================] - 480s 4s/step - ndcg_metric: 0.7962 - mrr_metric: 0.9525 - opa_metric: 0.6642 - loss: 4.3553 - regularization_loss: 0.0000e+00 - total_loss: 4.3553\n",
      "Epoch 4/10\n",
      "125/125 [==============================] - 477s 4s/step - ndcg_metric: 0.8047 - mrr_metric: 0.9559 - opa_metric: 0.6737 - loss: 4.3073 - regularization_loss: 0.0000e+00 - total_loss: 4.3073\n",
      "Epoch 5/10\n",
      "125/125 [==============================] - 479s 4s/step - ndcg_metric: 0.8090 - mrr_metric: 0.9584 - opa_metric: 0.6786 - loss: 4.2673 - regularization_loss: 0.0000e+00 - total_loss: 4.2673\n",
      "Epoch 6/10\n",
      "125/125 [==============================] - 477s 4s/step - ndcg_metric: 0.8131 - mrr_metric: 0.9589 - opa_metric: 0.6845 - loss: 4.2392 - regularization_loss: 0.0000e+00 - total_loss: 4.2392\n",
      "Epoch 7/10\n",
      "125/125 [==============================] - 478s 4s/step - ndcg_metric: 0.8176 - mrr_metric: 0.9624 - opa_metric: 0.6896 - loss: 4.2131 - regularization_loss: 0.0000e+00 - total_loss: 4.2131\n",
      "Epoch 8/10\n",
      "125/125 [==============================] - 477s 4s/step - ndcg_metric: 0.8221 - mrr_metric: 0.9631 - opa_metric: 0.6952 - loss: 4.1829 - regularization_loss: 0.0000e+00 - total_loss: 4.1829\n",
      "Epoch 9/10\n",
      "125/125 [==============================] - 476s 4s/step - ndcg_metric: 0.8261 - mrr_metric: 0.9643 - opa_metric: 0.7006 - loss: 4.1537 - regularization_loss: 0.0000e+00 - total_loss: 4.1537\n",
      "Epoch 10/10\n",
      "125/125 [==============================] - 478s 4s/step - ndcg_metric: 0.8298 - mrr_metric: 0.9655 - opa_metric: 0.7044 - loss: 4.1249 - regularization_loss: 0.0000e+00 - total_loss: 4.1249\n",
      "CPU times: user 1h 22min 53s, sys: 3min 32s, total: 1h 26min 25s\n",
      "Wall time: 1h 20min 15s\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "\n",
    "history = model.fit(train_dataset, epochs=EPOCHS)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "id": "c5922d44-d666-40b3-998a-d0ff141b5830",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 13 μs, sys: 1e+03 ns, total: 14 μs\n",
      "Wall time: 25.5 μs\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "\n",
    "show_on_test_batch = True"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "id": "2b332640-71f4-49ed-8861-8cef52748f63",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<_TakeDataset element_spec=(TensorSpec(shape=(None, 5, 224, 224, 3), dtype=tf.float32, name=None), TensorSpec(shape=(None, 5), dtype=tf.int32, name=None))>\n",
      "tf.Tensor([-0.99349564  1.1191123   0.31519493  0.2102942   0.83431655], shape=(5,), dtype=float32) tf.Tensor([0 4 1 2 3], shape=(5,), dtype=int32) tf.Tensor([0 4 2 1 3], shape=(5,), dtype=int32)\n",
      "tf.Tensor([ 0.49095434 -1.4347363   0.82706255  0.61680144  0.6294792 ], shape=(5,), dtype=float32) tf.Tensor([2 0 3 1 4], shape=(5,), dtype=int32) tf.Tensor([1 0 4 2 3], shape=(5,), dtype=int32)\n",
      "tf.Tensor([ 0.25340673  0.4590208   0.55869144 -0.71602017  0.12732488], shape=(5,), dtype=float32) tf.Tensor([2 4 3 1 0], shape=(5,), dtype=int32) tf.Tensor([2 3 4 0 1], shape=(5,), dtype=int32)\n",
      "CPU times: user 5.54 s, sys: 196 ms, total: 5.74 s\n",
      "Wall time: 4.62 s\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "\n",
    "if show_on_test_batch:\n",
    "    batch = train_dataset.take(1)\n",
    "    print(batch)\n",
    "    for images, labels in batch:\n",
    "        print(model(images)[0], labels[0], predict_rank(model, images)[0])\n",
    "        print(model(images)[1], labels[1], predict_rank(model, images)[1])\n",
    "        print(model(images)[2], labels[2], predict_rank(model, images)[2])\n",
    "        break"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "7b14999b-857e-4308-a7ff-ecedee201586",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2024-05-20 16:41:47.140455: W tensorflow/core/framework/op_kernel.cc:1827] INVALID_ARGUMENT: required broadcastable shapes\n",
      "2024-05-20 16:41:47.140547: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 17765919092235469636\n",
      "2024-05-20 16:41:47.140577: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 16180466714669466760\n",
      "2024-05-20 16:41:47.140607: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 5363367690352985740\n",
      "2024-05-20 16:41:47.140634: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 7922739778709309023\n",
      "2024-05-20 16:41:47.140661: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 17884303820219402077\n",
      "2024-05-20 16:41:47.140687: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 14594733506454866002\n",
      "2024-05-20 16:41:47.140722: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 9024727008196263698\n",
      "2024-05-20 16:41:47.140747: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 5430633925713952070\n",
      "2024-05-20 16:41:47.140775: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 17934148031573706251\n",
      "2024-05-20 16:41:47.140802: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 13979667810584546258\n",
      "2024-05-20 16:41:47.140828: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 5062579726965034151\n",
      "2024-05-20 16:41:47.140859: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 15979389616084063448\n",
      "2024-05-20 16:41:47.140882: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 13541995116182574497\n",
      "2024-05-20 16:41:47.140907: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 12050721633806162456\n",
      "2024-05-20 16:41:47.140929: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 9024763136604659399\n",
      "2024-05-20 16:41:47.140951: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 5102409302666927876\n",
      "2024-05-20 16:41:47.140973: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 2404989996113110502\n",
      "2024-05-20 16:41:47.141005: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 8973707270589228119\n",
      "2024-05-20 16:41:47.141032: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 9805534592640640856\n",
      "2024-05-20 16:41:47.141059: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 14905146073541585711\n",
      "2024-05-20 16:41:47.141086: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 1220404975777066103\n",
      "2024-05-20 16:41:47.141115: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 6506379758305439495\n",
      "2024-05-20 16:41:47.141142: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 10453747883338265636\n",
      "2024-05-20 16:41:47.141166: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 500051992367131838\n",
      "2024-05-20 16:41:47.141188: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 18160231801181763464\n",
      "2024-05-20 16:41:47.141233: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 1167859471092955757\n",
      "2024-05-20 16:41:47.141264: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 17968732996494883598\n",
      "2024-05-20 16:41:47.141291: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 12217227328250199803\n",
      "2024-05-20 16:41:47.141378: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 11966987901930880694\n",
      "2024-05-20 16:41:47.141446: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 3012336402191846332\n",
      "2024-05-20 16:41:47.141471: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 6178435293141748894\n",
      "2024-05-20 16:41:47.141502: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 14856947864843996698\n",
      "2024-05-20 16:41:47.141528: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 16270213262705168281\n",
      "2024-05-20 16:41:47.141553: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 17074910523315638089\n",
      "2024-05-20 16:41:47.141578: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 3342919263971169331\n",
      "2024-05-20 16:41:47.141602: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 18357642407811094240\n",
      "2024-05-20 16:41:47.141624: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 9403729215086794825\n",
      "2024-05-20 16:41:47.141646: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 675550408405261012\n",
      "2024-05-20 16:41:47.141669: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 2280305513452870726\n",
      "2024-05-20 16:41:47.141693: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 12448809275361877192\n",
      "2024-05-20 16:41:47.141722: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 4303484117615763199\n",
      "2024-05-20 16:41:47.141747: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 10914480893132382898\n",
      "2024-05-20 16:41:47.141773: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 13027351583920760140\n",
      "2024-05-20 16:41:47.141799: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 4488123098235807907\n",
      "2024-05-20 16:41:47.141824: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 7327255064288799007\n",
      "2024-05-20 16:41:47.141862: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 7209942794607127288\n",
      "2024-05-20 16:41:47.141895: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 7112853797529382087\n",
      "2024-05-20 16:41:47.141921: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 1038216722818478218\n",
      "2024-05-20 16:41:47.141946: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 10739681444248154323\n",
      "2024-05-20 16:41:47.141971: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 8803655558260126675\n",
      "2024-05-20 16:41:47.141995: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 15599935219844560840\n",
      "2024-05-20 16:41:47.142024: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 8901278872550684449\n",
      "2024-05-20 16:41:47.142049: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 4300736716435667572\n",
      "2024-05-20 16:41:47.142074: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 15809016469517941582\n",
      "2024-05-20 16:41:47.142099: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 15185949182888393782\n",
      "2024-05-20 16:41:47.142127: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 14367989791829639095\n",
      "2024-05-20 16:41:47.142153: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 12586305324587508270\n",
      "2024-05-20 16:41:47.142176: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 11411635535505630165\n",
      "2024-05-20 16:41:47.142202: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 9200397630078444999\n",
      "2024-05-20 16:41:47.142227: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 13290011788378637883\n",
      "2024-05-20 16:41:47.142252: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 14082391278945282306\n",
      "2024-05-20 16:41:47.142277: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 13832650391218342509\n",
      "2024-05-20 16:41:47.142302: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 7048821500006919249\n",
      "2024-05-20 16:41:47.142329: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 4395734140755496974\n",
      "2024-05-20 16:41:47.142353: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 3268753030542512553\n",
      "2024-05-20 16:41:47.142379: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 1642986860809728468\n",
      "2024-05-20 16:41:47.142411: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 1937553945724490496\n",
      "2024-05-20 16:41:47.142433: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 17330584828113040122\n",
      "2024-05-20 16:41:47.142461: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 137052750878218257\n",
      "2024-05-20 16:41:47.142489: W tensorflow/core/framework/op_kernel.cc:1827] INVALID_ARGUMENT: required broadcastable shapes\n",
      "2024-05-20 16:41:47.142544: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 18212498145869142288\n",
      "2024-05-20 16:41:47.142564: W tensorflow/core/framework/op_kernel.cc:1827] INVALID_ARGUMENT: required broadcastable shapes\n",
      "2024-05-20 16:41:47.142576: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 7475597136174136005\n",
      "2024-05-20 16:41:47.142601: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 2570443591838513769\n",
      "2024-05-20 16:41:47.142626: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 7878098204796902691\n",
      "2024-05-20 16:41:47.142650: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 7551535249785388509\n",
      "2024-05-20 16:41:47.142716: W tensorflow/core/framework/op_kernel.cc:1827] INVALID_ARGUMENT: required broadcastable shapes\n",
      "2024-05-20 16:41:49.657061: W tensorflow/core/framework/op_kernel.cc:1827] INVALID_ARGUMENT: required broadcastable shapes\n"
     ]
    },
    {
     "ename": "InvalidArgumentError",
     "evalue": "Graph execution error:\n\nDetected at node replica_1/ranking/sort_by_scores/SelectV2_1 defined at (most recent call last):\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/threading.py\", line 973, in _bootstrap\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/threading.py\", line 1016, in _bootstrap_inner\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/keras/src/engine/training.py\", line 2037, in run_step\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/tensorflow_recommenders/models/base.py\", line 88, in test_step\n\n  File \"/home/gorbuljaal/diploma/models/efficientnetb2.py\", line 86, in compute_loss\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/keras/src/utils/traceback_utils.py\", line 65, in error_handler\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/keras/src/engine/base_layer.py\", line 1149, in __call__\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/keras/src/utils/traceback_utils.py\", line 96, in error_handler\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/tensorflow_recommenders/tasks/ranking.py\", line 92, in call\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/tensorflow_recommenders/tasks/ranking.py\", line 97, in call\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/tensorflow_recommenders/tasks/ranking.py\", line 98, in call\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/keras/src/utils/metrics_utils.py\", line 77, in decorated\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/keras/src/metrics/base_metric.py\", line 140, in update_state_fn\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/tensorflow_ranking/python/keras/metrics.py\", line 190, in update_state\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/tensorflow_ranking/python/metrics_impl.py\", line 291, in compute\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/tensorflow_ranking/python/metrics_impl.py\", line 655, in _compute_impl\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/tensorflow_ranking/python/utils.py\", line 155, in sort_by_scores\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/tensorflow_ranking/python/utils.py\", line 156, in sort_by_scores\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/tensorflow_ranking/python/utils.py\", line 108, in _get_shuffle_indices\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/tensorflow_ranking/python/utils.py\", line 109, in _get_shuffle_indices\n\nDetected at node replica_1/ranking/sort_by_scores/SelectV2_1 defined at (most recent call last):\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/threading.py\", line 973, in _bootstrap\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/threading.py\", line 1016, in _bootstrap_inner\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/keras/src/engine/training.py\", line 2037, in run_step\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/tensorflow_recommenders/models/base.py\", line 88, in test_step\n\n  File \"/home/gorbuljaal/diploma/models/efficientnetb2.py\", line 86, in compute_loss\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/keras/src/utils/traceback_utils.py\", line 65, in error_handler\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/keras/src/engine/base_layer.py\", line 1149, in __call__\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/keras/src/utils/traceback_utils.py\", line 96, in error_handler\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/tensorflow_recommenders/tasks/ranking.py\", line 92, in call\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/tensorflow_recommenders/tasks/ranking.py\", line 97, in call\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/tensorflow_recommenders/tasks/ranking.py\", line 98, in call\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/keras/src/utils/metrics_utils.py\", line 77, in decorated\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/keras/src/metrics/base_metric.py\", line 140, in update_state_fn\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/tensorflow_ranking/python/keras/metrics.py\", line 190, in update_state\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/tensorflow_ranking/python/metrics_impl.py\", line 291, in compute\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/tensorflow_ranking/python/metrics_impl.py\", line 655, in _compute_impl\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/tensorflow_ranking/python/utils.py\", line 155, in sort_by_scores\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/tensorflow_ranking/python/utils.py\", line 156, in sort_by_scores\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/tensorflow_ranking/python/utils.py\", line 108, in _get_shuffle_indices\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/tensorflow_ranking/python/utils.py\", line 109, in _get_shuffle_indices\n\nDetected at node replica_1/ranking/sort_by_scores/SelectV2_1 defined at (most recent call last):\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/threading.py\", line 973, in _bootstrap\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/threading.py\", line 1016, in _bootstrap_inner\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/keras/src/engine/training.py\", line 2037, in run_step\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/tensorflow_recommenders/models/base.py\", line 88, in test_step\n\n  File \"/home/gorbuljaal/diploma/models/efficientnetb2.py\", line 86, in compute_loss\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/keras/src/utils/traceback_utils.py\", line 65, in error_handler\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/keras/src/engine/base_layer.py\", line 1149, in __call__\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/keras/src/utils/traceback_utils.py\", line 96, in error_handler\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/tensorflow_recommenders/tasks/ranking.py\", line 92, in call\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/tensorflow_recommenders/tasks/ranking.py\", line 97, in call\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/tensorflow_recommenders/tasks/ranking.py\", line 98, in call\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/keras/src/utils/metrics_utils.py\", line 77, in decorated\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/keras/src/metrics/base_metric.py\", line 140, in update_state_fn\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/tensorflow_ranking/python/keras/metrics.py\", line 190, in update_state\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/tensorflow_ranking/python/metrics_impl.py\", line 291, in compute\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/tensorflow_ranking/python/metrics_impl.py\", line 655, in _compute_impl\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/tensorflow_ranking/python/utils.py\", line 155, in sort_by_scores\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/tensorflow_ranking/python/utils.py\", line 156, in sort_by_scores\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/tensorflow_ranking/python/utils.py\", line 108, in _get_shuffle_indices\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/tensorflow_ranking/python/utils.py\", line 109, in _get_shuffle_indices\n\n3 root error(s) found.\n  (0) INVALID_ARGUMENT:  required broadcastable shapes\n\t [[{{node replica_1/ranking/sort_by_scores/SelectV2_1}}]]\n\t [[cond/then/_0/cond/cond_1/output/_743/_550]]\n\t [[div_no_nan/ReadVariableOp_3/_966]]\n  (1) INVALID_ARGUMENT:  required broadcastable shapes\n\t [[{{node replica_1/ranking/sort_by_scores/SelectV2_1}}]]\n\t [[cond/then/_0/cond/cond_1/output/_743/_550]]\n  (2) INVALID_ARGUMENT:  required broadcastable shapes\n\t [[{{node replica_1/ranking/sort_by_scores/SelectV2_1}}]]\n0 successful operations.\n0 derived errors ignored. [Op:__inference_test_function_195821]",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mInvalidArgumentError\u001b[0m                      Traceback (most recent call last)",
      "Cell \u001b[0;32mIn[14], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m test_metrics \u001b[38;5;241m=\u001b[39m \u001b[43mmodel\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mevaluate\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtest_images\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtest_labels\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mreturn_dict\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m)\u001b[49m\n\u001b[1;32m      2\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mTest metrics: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mtest_metrics\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m)\n",
      "File \u001b[0;32m~/.conda/envs/diploma/lib/python3.10/site-packages/keras/src/utils/traceback_utils.py:70\u001b[0m, in \u001b[0;36mfilter_traceback.<locals>.error_handler\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m     67\u001b[0m     filtered_tb \u001b[38;5;241m=\u001b[39m _process_traceback_frames(e\u001b[38;5;241m.\u001b[39m__traceback__)\n\u001b[1;32m     68\u001b[0m     \u001b[38;5;66;03m# To get the full stack trace, call:\u001b[39;00m\n\u001b[1;32m     69\u001b[0m     \u001b[38;5;66;03m# `tf.debugging.disable_traceback_filtering()`\u001b[39;00m\n\u001b[0;32m---> 70\u001b[0m     \u001b[38;5;28;01mraise\u001b[39;00m e\u001b[38;5;241m.\u001b[39mwith_traceback(filtered_tb) \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[1;32m     71\u001b[0m \u001b[38;5;28;01mfinally\u001b[39;00m:\n\u001b[1;32m     72\u001b[0m     \u001b[38;5;28;01mdel\u001b[39;00m filtered_tb\n",
      "File \u001b[0;32m~/.conda/envs/diploma/lib/python3.10/site-packages/tensorflow/python/eager/execute.py:53\u001b[0m, in \u001b[0;36mquick_execute\u001b[0;34m(op_name, num_outputs, inputs, attrs, ctx, name)\u001b[0m\n\u001b[1;32m     51\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m     52\u001b[0m   ctx\u001b[38;5;241m.\u001b[39mensure_initialized()\n\u001b[0;32m---> 53\u001b[0m   tensors \u001b[38;5;241m=\u001b[39m pywrap_tfe\u001b[38;5;241m.\u001b[39mTFE_Py_Execute(ctx\u001b[38;5;241m.\u001b[39m_handle, device_name, op_name,\n\u001b[1;32m     54\u001b[0m                                       inputs, attrs, num_outputs)\n\u001b[1;32m     55\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m core\u001b[38;5;241m.\u001b[39m_NotOkStatusException \u001b[38;5;28;01mas\u001b[39;00m e:\n\u001b[1;32m     56\u001b[0m   \u001b[38;5;28;01mif\u001b[39;00m name \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n",
      "\u001b[0;31mInvalidArgumentError\u001b[0m: Graph execution error:\n\nDetected at node replica_1/ranking/sort_by_scores/SelectV2_1 defined at (most recent call last):\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/threading.py\", line 973, in _bootstrap\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/threading.py\", line 1016, in _bootstrap_inner\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/keras/src/engine/training.py\", line 2037, in run_step\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/tensorflow_recommenders/models/base.py\", line 88, in test_step\n\n  File \"/home/gorbuljaal/diploma/models/efficientnetb2.py\", line 86, in compute_loss\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/keras/src/utils/traceback_utils.py\", line 65, in error_handler\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/keras/src/engine/base_layer.py\", line 1149, in __call__\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/keras/src/utils/traceback_utils.py\", line 96, in error_handler\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/tensorflow_recommenders/tasks/ranking.py\", line 92, in call\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/tensorflow_recommenders/tasks/ranking.py\", line 97, in call\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/tensorflow_recommenders/tasks/ranking.py\", line 98, in call\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/keras/src/utils/metrics_utils.py\", line 77, in decorated\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/keras/src/metrics/base_metric.py\", line 140, in update_state_fn\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/tensorflow_ranking/python/keras/metrics.py\", line 190, in update_state\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/tensorflow_ranking/python/metrics_impl.py\", line 291, in compute\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/tensorflow_ranking/python/metrics_impl.py\", line 655, in _compute_impl\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/tensorflow_ranking/python/utils.py\", line 155, in sort_by_scores\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/tensorflow_ranking/python/utils.py\", line 156, in sort_by_scores\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/tensorflow_ranking/python/utils.py\", line 108, in _get_shuffle_indices\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/tensorflow_ranking/python/utils.py\", line 109, in _get_shuffle_indices\n\nDetected at node replica_1/ranking/sort_by_scores/SelectV2_1 defined at (most recent call last):\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/threading.py\", line 973, in _bootstrap\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/threading.py\", line 1016, in _bootstrap_inner\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/keras/src/engine/training.py\", line 2037, in run_step\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/tensorflow_recommenders/models/base.py\", line 88, in test_step\n\n  File \"/home/gorbuljaal/diploma/models/efficientnetb2.py\", line 86, in compute_loss\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/keras/src/utils/traceback_utils.py\", line 65, in error_handler\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/keras/src/engine/base_layer.py\", line 1149, in __call__\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/keras/src/utils/traceback_utils.py\", line 96, in error_handler\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/tensorflow_recommenders/tasks/ranking.py\", line 92, in call\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/tensorflow_recommenders/tasks/ranking.py\", line 97, in call\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/tensorflow_recommenders/tasks/ranking.py\", line 98, in call\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/keras/src/utils/metrics_utils.py\", line 77, in decorated\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/keras/src/metrics/base_metric.py\", line 140, in update_state_fn\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/tensorflow_ranking/python/keras/metrics.py\", line 190, in update_state\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/tensorflow_ranking/python/metrics_impl.py\", line 291, in compute\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/tensorflow_ranking/python/metrics_impl.py\", line 655, in _compute_impl\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/tensorflow_ranking/python/utils.py\", line 155, in sort_by_scores\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/tensorflow_ranking/python/utils.py\", line 156, in sort_by_scores\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/tensorflow_ranking/python/utils.py\", line 108, in _get_shuffle_indices\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/tensorflow_ranking/python/utils.py\", line 109, in _get_shuffle_indices\n\nDetected at node replica_1/ranking/sort_by_scores/SelectV2_1 defined at (most recent call last):\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/threading.py\", line 973, in _bootstrap\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/threading.py\", line 1016, in _bootstrap_inner\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/keras/src/engine/training.py\", line 2037, in run_step\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/tensorflow_recommenders/models/base.py\", line 88, in test_step\n\n  File \"/home/gorbuljaal/diploma/models/efficientnetb2.py\", line 86, in compute_loss\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/keras/src/utils/traceback_utils.py\", line 65, in error_handler\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/keras/src/engine/base_layer.py\", line 1149, in __call__\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/keras/src/utils/traceback_utils.py\", line 96, in error_handler\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/tensorflow_recommenders/tasks/ranking.py\", line 92, in call\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/tensorflow_recommenders/tasks/ranking.py\", line 97, in call\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/tensorflow_recommenders/tasks/ranking.py\", line 98, in call\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/keras/src/utils/metrics_utils.py\", line 77, in decorated\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/keras/src/metrics/base_metric.py\", line 140, in update_state_fn\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/tensorflow_ranking/python/keras/metrics.py\", line 190, in update_state\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/tensorflow_ranking/python/metrics_impl.py\", line 291, in compute\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/tensorflow_ranking/python/metrics_impl.py\", line 655, in _compute_impl\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/tensorflow_ranking/python/utils.py\", line 155, in sort_by_scores\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/tensorflow_ranking/python/utils.py\", line 156, in sort_by_scores\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/tensorflow_ranking/python/utils.py\", line 108, in _get_shuffle_indices\n\n  File \"/home/gorbuljaal/.conda/envs/diploma/lib/python3.10/site-packages/tensorflow_ranking/python/utils.py\", line 109, in _get_shuffle_indices\n\n3 root error(s) found.\n  (0) INVALID_ARGUMENT:  required broadcastable shapes\n\t [[{{node replica_1/ranking/sort_by_scores/SelectV2_1}}]]\n\t [[cond/then/_0/cond/cond_1/output/_743/_550]]\n\t [[div_no_nan/ReadVariableOp_3/_966]]\n  (1) INVALID_ARGUMENT:  required broadcastable shapes\n\t [[{{node replica_1/ranking/sort_by_scores/SelectV2_1}}]]\n\t [[cond/then/_0/cond/cond_1/output/_743/_550]]\n  (2) INVALID_ARGUMENT:  required broadcastable shapes\n\t [[{{node replica_1/ranking/sort_by_scores/SelectV2_1}}]]\n0 successful operations.\n0 derived errors ignored. [Op:__inference_test_function_195821]"
     ]
    }
   ],
   "source": [
    "test_metrics = model.evaluate(test_images, test_labels, return_dict=True)\n",
    "print(f\"Test metrics: {test_metrics}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8827a19b-118b-4bf7-b10d-f6bd44160904",
   "metadata": {},
   "outputs": [],
   "source": [
    "val_metrics = model.evaluate(val_images, val_labels, return_dict=True)\n",
    "print(f\"Val metrics: {test_metrics}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "id": "96e8acdd-fe5a-442a-a447-f71a1461a92d",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(68992, 512), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1de9c2d540>, 139766507189664), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(68992, 512), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1de9c2d540>, 139766507189664), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(512,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1de9c2e2c0>, 139766507187504), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(512,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1de9c2e2c0>, 139766507187504), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(512, 256), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1de9c2f100>, 139766433876208), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(512, 256), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1de9c2f100>, 139766433876208), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(256,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1de9cfa2c0>, 139766433889488), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(256,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1de9cfa2c0>, 139766433889488), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(256, 128), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1de9cf9ab0>, 139766433884848), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(256, 128), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1de9cf9ab0>, 139766433884848), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(128,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1de9cf9630>, 139766433884928), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(128,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1de9cf9630>, 139766433884928), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(128, 64), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1de9cfae60>, 139766433889088), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(128, 64), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1de9cfae60>, 139766433889088), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(64,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1de9bd2980>, 139766433881248), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(64,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1de9bd2980>, 139766433881248), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(64, 1), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1de9bd10f0>, 139766433878208), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(64, 1), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1de9bd10f0>, 139766433878208), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(1,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1de9bd1120>, 139766433885328), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(1,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1de9bd1120>, 139766433885328), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(68992, 512), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1de9c2d540>, 139766507189664), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(68992, 512), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1de9c2d540>, 139766507189664), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(512,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1de9c2e2c0>, 139766507187504), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(512,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1de9c2e2c0>, 139766507187504), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(512, 256), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1de9c2f100>, 139766433876208), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(512, 256), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1de9c2f100>, 139766433876208), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(256,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1de9cfa2c0>, 139766433889488), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(256,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1de9cfa2c0>, 139766433889488), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(256, 128), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1de9cf9ab0>, 139766433884848), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(256, 128), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1de9cf9ab0>, 139766433884848), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(128,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1de9cf9630>, 139766433884928), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(128,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1de9cf9630>, 139766433884928), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(128, 64), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1de9cfae60>, 139766433889088), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(128, 64), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1de9cfae60>, 139766433889088), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(64,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1de9bd2980>, 139766433881248), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(64,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1de9bd2980>, 139766433881248), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(64, 1), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1de9bd10f0>, 139766433878208), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(64, 1), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1de9bd10f0>, 139766433878208), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(1,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1de9bd1120>, 139766433885328), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(1,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1de9bd1120>, 139766433885328), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Assets written to: tesla_saved_models/EfficientNetB2RankingModel_20240612_192343_unfreezed_1e-05/assets\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Assets written to: tesla_saved_models/EfficientNetB2RankingModel_20240612_192343_unfreezed_1e-05/assets\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model saved in tesla_saved_models/EfficientNetB2RankingModel_20240612_192343_unfreezed_1e-05 as EfficientNetB2RankingModel_20240612_192343_unfreezed_1e-05\n",
      "CPU times: user 34.1 s, sys: 939 ms, total: 35.1 s\n",
      "Wall time: 34.7 s\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "\n",
    "if SAVE_MODEL:\n",
    "    save_model_with_timestamp_and_lr(model, lr)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e53f95fc-bc42-4a5e-895b-ef64bfd77e33",
   "metadata": {},
   "source": [
    "### Evaluating"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "93f1885f-4537-4ab7-aa23-5920db1a5dc7",
   "metadata": {},
   "outputs": [],
   "source": [
    "loaded_model = tf.keras.models.load_model('tesla_saved_models/EfficientNetB2RankingModel_20240612_162156_unfreezed_0.0001', \n",
    "                                          custom_objects={'ListMLELoss': tfr.keras.losses.ListMLELoss()})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6d73029c-33e3-4dc9-bbb0-edb679a2b9ea",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "%%time\n",
    "\n",
    "metrics = compute_dataset_metrics(loaded_model, train_dataset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7325ec4d-5e7e-4a38-a911-5f35b72d7031",
   "metadata": {},
   "outputs": [],
   "source": [
    "%%time\n",
    "\n",
    "print(metrics)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a048ba19-9cce-42a1-9583-93642fd43cb9",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "%%time\n",
    "\n",
    "metrics = compute_dataset_metrics(loaded_model, val_dataset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d0fe478c-3044-4eac-8a8c-f9773cc54f3d",
   "metadata": {},
   "outputs": [],
   "source": [
    "%%time\n",
    "\n",
    "print(metrics)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "66f3ccef-6b53-4fa5-847c-08a103934ae7",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "%%time\n",
    "\n",
    "metrics = compute_dataset_metrics(loaded_model, test_dataset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "29b9de54-99c7-48d5-8a35-60cafaab2410",
   "metadata": {},
   "outputs": [],
   "source": [
    "%%time\n",
    "\n",
    "print(metrics)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7c67a2e1-c048-440c-a94b-bdaaf9331234",
   "metadata": {},
   "outputs": [],
   "source": [
    "%%time\n",
    "\n",
    "show_on_test_batch = True"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c23afae5-d7b2-46d3-a443-f6a19c69acf5",
   "metadata": {},
   "outputs": [],
   "source": [
    "%%time\n",
    "\n",
    "if show_on_test_batch:\n",
    "    batch = test_dataset.take(1)\n",
    "    print(batch)\n",
    "    for images, labels in batch:\n",
    "        print(model(images)[0], labels[0], predict_rank(model, images)[0])\n",
    "        print(model(images)[1], labels[1], predict_rank(model, images)[1])\n",
    "        print(model(images)[2], labels[2], predict_rank(model, images)[2])\n",
    "        break"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3c32a764-9fee-442c-8a14-688a1c9398e0",
   "metadata": {},
   "outputs": [],
   "source": [
    "%%time\n",
    "\n",
    "plot_metrics(metrics)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "05e6876c-c613-4215-8e07-1831ecc1de65",
   "metadata": {},
   "source": [
    "## MobileNetV2"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "345dabdd-6bfb-4f17-8315-329477a99ffb",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true
   },
   "source": [
    "### Training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 223,
   "id": "9f2ba7fd-373a-4ece-902e-8b2fd2a871de",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 1.03 s, sys: 11.8 ms, total: 1.04 s\n",
      "Wall time: 999 ms\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "\n",
    "loss = tfr.keras.losses.ListMLELoss()\n",
    "model = MobileNetV2RankingModel(loss, trainable=True)\n",
    "lr = 1e-6\n",
    "optimizer = tf.keras.optimizers.Adam(learning_rate=lr, clipnorm=1.0)\n",
    "model.compile(optimizer=optimizer)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 224,
   "id": "fe2bfd3f-4674-4033-b940-8847da3e392f",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 2.91 ms, sys: 125 μs, total: 3.03 ms\n",
      "Wall time: 1.41 ms\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "\n",
    "batch = train_dataset.take(1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 225,
   "id": "90bdbb9c-e9e3-42d2-9589-d31b072ee4f8",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/1 [==============================] - 8s 8s/step\n",
      "CPU times: user 7.91 s, sys: 519 ms, total: 8.43 s\n",
      "Wall time: 8.16 s\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2024-06-13 19:59:50.275483: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 17304393383681730748\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "array([[2.143233  , 0.9432556 , 1.0862399 , 0.94055927, 2.80851   ],\n",
       "       [1.5322427 , 1.5489473 , 1.4798476 , 1.106171  , 2.4911616 ],\n",
       "       [2.592333  , 2.465049  , 1.4243081 , 2.3979075 , 2.689603  ],\n",
       "       [2.745817  , 2.2013452 , 2.2769206 , 3.0165296 , 1.3754466 ],\n",
       "       [0.79224217, 1.2226317 , 1.1704174 , 1.4912685 , 2.2323823 ],\n",
       "       [1.6916718 , 1.9854337 , 0.68894833, 2.782864  , 0.897815  ],\n",
       "       [2.4008884 , 1.3125663 , 1.3350594 , 1.2806475 , 2.9795537 ],\n",
       "       [1.1316828 , 0.67548585, 0.9705803 , 1.9435129 , 1.6352355 ],\n",
       "       [1.515203  , 2.039849  , 1.336577  , 1.6055931 , 2.2540967 ],\n",
       "       [0.83313316, 1.3011795 , 0.8993101 , 2.2642138 , 1.6383263 ],\n",
       "       [1.6354178 , 1.7646323 , 1.9832821 , 1.4762974 , 2.1257174 ],\n",
       "       [1.9724741 , 1.3359352 , 0.967936  , 1.1763414 , 0.82226515],\n",
       "       [2.0250194 , 0.72335434, 2.390897  , 2.0985181 , 1.535354  ],\n",
       "       [1.6653976 , 2.1655707 , 2.604019  , 2.4213753 , 1.8036151 ],\n",
       "       [2.1179593 , 1.500165  , 1.4393764 , 1.8202478 , 1.2483344 ],\n",
       "       [1.8888103 , 2.276077  , 0.79105103, 0.38763517, 1.9785675 ],\n",
       "       [2.3019226 , 1.3854228 , 2.2595384 , 1.3365227 , 1.1619247 ],\n",
       "       [2.0656195 , 1.3159323 , 1.8650231 , 1.9908497 , 1.5526389 ],\n",
       "       [0.83465093, 1.463366  , 0.5262041 , 0.5373869 , 1.3183748 ],\n",
       "       [1.9068997 , 1.8502873 , 1.9436438 , 2.1666563 , 1.7822819 ],\n",
       "       [1.9957349 , 1.9655442 , 2.4542239 , 1.7176741 , 0.5958042 ],\n",
       "       [2.2642589 , 2.2935638 , 2.3507037 , 1.7707939 , 1.8654662 ],\n",
       "       [2.3671775 , 1.8437917 , 1.5752937 , 2.578784  , 1.5036416 ],\n",
       "       [1.9328501 , 2.421222  , 1.7456541 , 1.701083  , 2.7616534 ],\n",
       "       [1.491605  , 1.0976712 , 1.0584545 , 2.5984876 , 1.16502   ],\n",
       "       [2.627873  , 1.5570441 , 2.285331  , 1.1406536 , 2.7706141 ],\n",
       "       [2.9231007 , 1.8511457 , 1.6916256 , 2.7785623 , 2.4537432 ],\n",
       "       [1.32276   , 1.3056438 , 2.6263654 , 1.5631658 , 1.6235579 ],\n",
       "       [2.4761684 , 1.8668615 , 1.2893455 , 0.94607633, 0.9254132 ],\n",
       "       [2.6071153 , 1.8130071 , 2.914587  , 1.3011105 , 1.1218418 ]],\n",
       "      dtype=float32)"
      ]
     },
     "execution_count": 225,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "%%time\n",
    "\n",
    "model.predict(batch)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 226,
   "id": "c040507c-2014-4cea-821f-71661cbbe624",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model: \"mobile_net_v2_ranking_model_1\"\n",
      "_________________________________________________________________\n",
      " Layer (type)                Output Shape              Param #   \n",
      "=================================================================\n",
      " mobilenetv2_1.00_224 (Func  (None, 7, 7, 1280)        2257984   \n",
      " tional)                                                         \n",
      "                                                                 \n",
      " flatten_5 (Flatten)         multiple                  0         \n",
      "                                                                 \n",
      " sequential_10 (Sequential)  (None, 64)                32285632  \n",
      "                                                                 \n",
      " sequential_11 (Sequential)  (None, 1)                 65        \n",
      "                                                                 \n",
      " ranking_5 (Ranking)         multiple                  0 (unused)\n",
      "                                                                 \n",
      "=================================================================\n",
      "Total params: 34543681 (131.77 MB)\n",
      "Trainable params: 34509569 (131.64 MB)\n",
      "Non-trainable params: 34112 (133.25 KB)\n",
      "_________________________________________________________________\n",
      "CPU times: user 34.9 ms, sys: 31 μs, total: 34.9 ms\n",
      "Wall time: 32.6 ms\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "\n",
    "model.summary()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 227,
   "id": "0b5d0151-155a-4484-83e7-3e88f25ec51f",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1/10\n",
      "    125/Unknown - 459s 3s/step - ndcg_metric: 0.7698 - mrr_metric: 0.9515 - opa_metric: 0.6190 - loss: 4.6207 - regularization_loss: 0.0000e+00 - total_loss: 4.6207"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2024-06-13 20:07:33.140354: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 15114943260701245815\n",
      "2024-06-13 20:07:33.140419: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 9657179392313623693\n",
      "2024-06-13 20:07:33.140439: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 16693319048054662895\n",
      "2024-06-13 20:07:33.140456: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 8096078330843432751\n",
      "2024-06-13 20:07:33.140470: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 15404705991176161545\n",
      "2024-06-13 20:07:33.140495: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 4663905333049871283\n",
      "2024-06-13 20:07:33.140512: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 17836382546822663339\n",
      "2024-06-13 20:07:33.140528: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 16244531770958515673\n",
      "2024-06-13 20:07:33.140545: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 3049653229894782251\n",
      "2024-06-13 20:07:33.140562: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 751386356433997959\n",
      "2024-06-13 20:07:33.140578: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 15822277143874251671\n",
      "2024-06-13 20:07:33.140595: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 17511763576662635231\n",
      "2024-06-13 20:07:33.140611: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 12700623285875739925\n",
      "2024-06-13 20:07:33.140627: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 749002257459992451\n",
      "2024-06-13 20:07:33.140643: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 10655674866177505557\n",
      "2024-06-13 20:07:33.140658: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 5166906271864593159\n",
      "2024-06-13 20:07:33.140674: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 7301457943667211311\n",
      "2024-06-13 20:07:33.140690: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 4219160613621357877\n",
      "2024-06-13 20:07:33.140716: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 3184980622549794900\n",
      "2024-06-13 20:07:33.140733: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 888424606397120838\n",
      "2024-06-13 20:07:33.140749: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 15697505531988165444\n",
      "2024-06-13 20:07:33.140765: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 9866230275292554312\n",
      "2024-06-13 20:07:33.140782: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 423997211039810910\n",
      "2024-06-13 20:07:33.140798: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 7391758955667464420\n",
      "2024-06-13 20:07:33.140813: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 17304393383681730748\n",
      "2024-06-13 20:07:33.140829: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 6760217070533274890\n",
      "2024-06-13 20:07:33.140843: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 12666583633844358604\n",
      "2024-06-13 20:07:33.140859: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 3152995912012083000\n",
      "2024-06-13 20:07:33.140875: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 1770810675395260570\n",
      "2024-06-13 20:07:33.140906: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 5332069376591909946\n",
      "2024-06-13 20:07:33.141066: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 534414480989533416\n",
      "2024-06-13 20:07:33.141088: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 8357363976823593134\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "125/125 [==============================] - 463s 4s/step - ndcg_metric: 0.7698 - mrr_metric: 0.9515 - opa_metric: 0.6190 - loss: 4.6193 - regularization_loss: 0.0000e+00 - total_loss: 4.6193\n",
      "Epoch 2/10\n",
      "125/125 [==============================] - 446s 4s/step - ndcg_metric: 0.8199 - mrr_metric: 0.9729 - opa_metric: 0.7010 - loss: 4.1312 - regularization_loss: 0.0000e+00 - total_loss: 4.1312\n",
      "Epoch 3/10\n",
      "125/125 [==============================] - 447s 4s/step - ndcg_metric: 0.8502 - mrr_metric: 0.9828 - opa_metric: 0.7420 - loss: 3.8413 - regularization_loss: 0.0000e+00 - total_loss: 3.8413\n",
      "Epoch 4/10\n",
      "125/125 [==============================] - 447s 4s/step - ndcg_metric: 0.8699 - mrr_metric: 0.9881 - opa_metric: 0.7717 - loss: 3.6091 - regularization_loss: 0.0000e+00 - total_loss: 3.6091\n",
      "Epoch 5/10\n",
      "125/125 [==============================] - 447s 4s/step - ndcg_metric: 0.8859 - mrr_metric: 0.9921 - opa_metric: 0.7953 - loss: 3.4060 - regularization_loss: 0.0000e+00 - total_loss: 3.4060\n",
      "Epoch 6/10\n",
      "125/125 [==============================] - 440s 4s/step - ndcg_metric: 0.8988 - mrr_metric: 0.9964 - opa_metric: 0.8172 - loss: 3.2224 - regularization_loss: 0.0000e+00 - total_loss: 3.2224\n",
      "Epoch 7/10\n",
      "125/125 [==============================] - 439s 4s/step - ndcg_metric: 0.9106 - mrr_metric: 0.9984 - opa_metric: 0.8366 - loss: 3.0515 - regularization_loss: 0.0000e+00 - total_loss: 3.0515\n",
      "Epoch 8/10\n",
      "125/125 [==============================] - 436s 3s/step - ndcg_metric: 0.9206 - mrr_metric: 0.9989 - opa_metric: 0.8537 - loss: 2.8912 - regularization_loss: 0.0000e+00 - total_loss: 2.8912\n",
      "Epoch 9/10\n",
      "125/125 [==============================] - 445s 4s/step - ndcg_metric: 0.9298 - mrr_metric: 0.9991 - opa_metric: 0.8698 - loss: 2.7389 - regularization_loss: 0.0000e+00 - total_loss: 2.7389\n",
      "Epoch 10/10\n",
      "125/125 [==============================] - 444s 4s/step - ndcg_metric: 0.9381 - mrr_metric: 0.9995 - opa_metric: 0.8845 - loss: 2.5937 - regularization_loss: 0.0000e+00 - total_loss: 2.5937\n",
      "CPU times: user 1h 15min 38s, sys: 2min 47s, total: 1h 18min 26s\n",
      "Wall time: 1h 14min 15s\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "\n",
    "history = model.fit(train_dataset, epochs=EPOCHS)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 228,
   "id": "377613a0-644f-4283-8a43-672073d68ff0",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 16 μs, sys: 0 ns, total: 16 μs\n",
      "Wall time: 29.6 μs\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "\n",
    "show_on_test_batch = True"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 229,
   "id": "795b3cf2-939a-4327-b444-e31fe84ed1b8",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<_TakeDataset element_spec=(TensorSpec(shape=(None, 5, 224, 224, 3), dtype=tf.float32, name=None), TensorSpec(shape=(None, 5), dtype=tf.int32, name=None))>\n",
      "tf.Tensor([ 2.5409212  3.8749108 -0.3890914  1.4947778  3.920985 ], shape=(5,), dtype=float32) tf.Tensor([2 3 0 1 4], shape=(5,), dtype=int32) tf.Tensor([2 3 0 1 4], shape=(5,), dtype=int32)\n",
      "tf.Tensor([3.3002145 2.1477892 3.03918   3.29352   0.7454258], shape=(5,), dtype=float32) tf.Tensor([4 1 2 3 0], shape=(5,), dtype=int32) tf.Tensor([4 1 2 3 0], shape=(5,), dtype=int32)\n",
      "tf.Tensor([1.7644613 4.204424  3.3347094 1.4096589 4.4512434], shape=(5,), dtype=float32) tf.Tensor([0 3 2 1 4], shape=(5,), dtype=int32) tf.Tensor([1 3 2 0 4], shape=(5,), dtype=int32)\n",
      "CPU times: user 4.27 s, sys: 261 ms, total: 4.53 s\n",
      "Wall time: 4.04 s\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "\n",
    "if show_on_test_batch:\n",
    "    batch = train_dataset.take(1)\n",
    "    print(batch)\n",
    "    for images, labels in batch:\n",
    "        print(model(images)[0], labels[0], predict_rank(model, images)[0])\n",
    "        print(model(images)[1], labels[1], predict_rank(model, images)[1])\n",
    "        print(model(images)[2], labels[2], predict_rank(model, images)[2])\n",
    "        break"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "id": "7a613d48-e87b-4029-8c49-ad7271d4b161",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/1 [==============================] - 2s 2s/step - ndcg_metric: 0.9036 - mrr_metric: 1.0000 - opa_metric: 0.8200 - loss: 12.8485 - regularization_loss: 0.0000e+00 - total_loss: 12.8485\n",
      "Test metrics: {'ndcg_metric': 0.9036268591880798, 'mrr_metric': 1.0, 'opa_metric': 0.8199999928474426, 'loss': 12.848492622375488, 'regularization_loss': 0, 'total_loss': 12.848492622375488}\n"
     ]
    }
   ],
   "source": [
    "test_metrics = model.evaluate(test_images, test_labels, return_dict=True)\n",
    "print(f\"Test metrics: {test_metrics}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "id": "e24be2c4-fd75-44af-a855-ebf73351f8fa",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/1 [==============================] - 0s 261ms/step - ndcg_metric: 0.9040 - mrr_metric: 1.0000 - opa_metric: 0.8160 - loss: 12.2982 - regularization_loss: 0.0000e+00 - total_loss: 12.2982\n",
      "Val metrics: {'ndcg_metric': 0.9036268591880798, 'mrr_metric': 1.0, 'opa_metric': 0.8199999928474426, 'loss': 12.848492622375488, 'regularization_loss': 0, 'total_loss': 12.848492622375488}\n"
     ]
    }
   ],
   "source": [
    "val_metrics = model.evaluate(val_images, val_labels, return_dict=True)\n",
    "print(f\"Val metrics: {test_metrics}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 230,
   "id": "01b55ea4-7685-4a6d-9c28-bea559eabed2",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(62720, 512), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c8337f640>, 139760432170048), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(62720, 512), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c8337f640>, 139760432170048), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(512,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c8337ed70>, 139759905954096), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(512,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c8337ed70>, 139759905954096), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(512, 256), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c8339ffa0>, 139760424997616), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(512, 256), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c8339ffa0>, 139760424997616), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(256,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c8339d690>, 139760424993856), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(256,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c8339d690>, 139760424993856), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(256, 128), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c8339e560>, 139760431962816), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(256, 128), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c8339e560>, 139760431962816), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(128,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c8339c5e0>, 139760431969056), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(128,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c8339c5e0>, 139760431969056), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(128, 64), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c8336fca0>, 139760431958416), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(128, 64), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c8336fca0>, 139760431958416), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(64,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c8336de70>, 139760431958496), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(64,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c8336de70>, 139760431958496), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(64, 1), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c8336e5c0>, 139759906744368), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(64, 1), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c8336e5c0>, 139759906744368), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(1,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c8336c5e0>, 139759906747808), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(1,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c8336c5e0>, 139759906747808), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(62720, 512), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c8337f640>, 139760432170048), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(62720, 512), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c8337f640>, 139760432170048), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(512,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c8337ed70>, 139759905954096), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(512,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c8337ed70>, 139759905954096), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(512, 256), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c8339ffa0>, 139760424997616), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(512, 256), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c8339ffa0>, 139760424997616), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(256,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c8339d690>, 139760424993856), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(256,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c8339d690>, 139760424993856), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(256, 128), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c8339e560>, 139760431962816), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(256, 128), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c8339e560>, 139760431962816), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(128,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c8339c5e0>, 139760431969056), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(128,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c8339c5e0>, 139760431969056), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(128, 64), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c8336fca0>, 139760431958416), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(128, 64), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c8336fca0>, 139760431958416), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(64,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c8336de70>, 139760431958496), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(64,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c8336de70>, 139760431958496), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(64, 1), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c8336e5c0>, 139759906744368), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(64, 1), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c8336e5c0>, 139759906744368), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(1,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c8336c5e0>, 139759906747808), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(1,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c8336c5e0>, 139759906747808), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Assets written to: tesla_saved_models/MobileNetV2RankingModel_20240613_211410_unfreezed_1e-06/assets\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Assets written to: tesla_saved_models/MobileNetV2RankingModel_20240613_211410_unfreezed_1e-06/assets\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model saved in tesla_saved_models/MobileNetV2RankingModel_20240613_211410_unfreezed_1e-06 as MobileNetV2RankingModel_20240613_211410_unfreezed_1e-06\n",
      "CPU times: user 18.5 s, sys: 1.06 s, total: 19.5 s\n",
      "Wall time: 19.3 s\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "\n",
    "if SAVE_MODEL:\n",
    "    save_model_with_timestamp_and_lr(model, lr)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "57b2ae13-e7d2-476d-9eb6-76341887e39a",
   "metadata": {},
   "source": [
    "### Evaluating"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7062626b-a4d8-4366-aedc-89db01c2aba0",
   "metadata": {},
   "outputs": [],
   "source": [
    "loaded_model = tf.keras.models.load_model('saved_models/MobileNetV2RankingModel_20240611_192117', \n",
    "                                          custom_objects={'ListMLELoss': tfr.keras.losses.ListMLELoss()})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8f80f4f7-48ea-4fff-b9a8-492deb4af2c6",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "%%time\n",
    "\n",
    "metrics = compute_dataset_metrics(loaded_model, train_dataset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "938bdb77-98f1-4270-b0e8-04bfe1dbfd98",
   "metadata": {},
   "outputs": [],
   "source": [
    "%%time\n",
    "\n",
    "print(metrics)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fee4e36f-d17e-4646-b6ab-36bef00ed57d",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "%%time\n",
    "\n",
    "metrics = compute_dataset_metrics(loaded_model, val_dataset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "50cb950e-7b28-4601-bcb1-f5e3e738406e",
   "metadata": {},
   "outputs": [],
   "source": [
    "%%time\n",
    "\n",
    "print(metrics)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6e9afb61-fa4b-48ce-8e32-27b010d82cec",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "%%time\n",
    "\n",
    "metrics = compute_dataset_metrics(loaded_model, test_dataset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "869d1401-50c1-403f-8df2-121cf4c2db95",
   "metadata": {},
   "outputs": [],
   "source": [
    "%%time\n",
    "\n",
    "print(metrics)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3826cd8c-65b7-402c-b41e-9e87db29e30b",
   "metadata": {},
   "outputs": [],
   "source": [
    "%%time\n",
    "\n",
    "show_on_test_batch = True"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9dda2f1e-c949-458c-9fd5-be358532f488",
   "metadata": {},
   "outputs": [],
   "source": [
    "%%time\n",
    "\n",
    "if show_on_test_batch:\n",
    "    batch = test_dataset.take(1)\n",
    "    print(batch)\n",
    "    for images, labels in batch:\n",
    "        print(model(images)[0], labels[0], predict_rank(model, images)[0])\n",
    "        print(model(images)[1], labels[1], predict_rank(model, images)[1])\n",
    "        print(model(images)[2], labels[2], predict_rank(model, images)[2])\n",
    "        break"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9dc35819-3073-469f-b5ac-69413294f76b",
   "metadata": {},
   "outputs": [],
   "source": [
    "%%time\n",
    "\n",
    "plot_metrics(metrics)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "86c01a73-4dde-4b22-994a-23155d435c0a",
   "metadata": {},
   "source": [
    "## VGG16"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "384e8bc1-a3b8-4b9a-ad66-a39b0cb6c2a4",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true
   },
   "source": [
    "### Training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 255,
   "id": "f5e04a7f-49ad-44b1-9351-fa2afb42d3f9",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 259 ms, sys: 16 ms, total: 275 ms\n",
      "Wall time: 265 ms\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "\n",
    "loss = tfr.keras.losses.ListMLELoss()\n",
    "model = VGG16RankingModel(loss, trainable=True)\n",
    "lr = 1e-6\n",
    "optimizer = tf.keras.optimizers.Adam(learning_rate=lr, clipnorm=1.0)\n",
    "model.compile(optimizer=optimizer)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 256,
   "id": "bb753215-c2d3-48cf-a40e-09321eea62ac",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 2.76 ms, sys: 112 μs, total: 2.88 ms\n",
      "Wall time: 1.35 ms\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "\n",
    "batch = train_dataset.take(1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 257,
   "id": "bdb78cb0-a417-42fb-be0d-89313471256c",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/1 [==============================] - 7s 7s/step\n",
      "CPU times: user 7.06 s, sys: 291 ms, total: 7.35 s\n",
      "Wall time: 7.13 s\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2024-06-14 02:31:43.536386: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 15390714029690532465\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "array([[-0.26244977, -0.12680954, -0.28612444, -0.15588874, -0.33936387],\n",
       "       [-0.15316296, -0.19989483, -0.18946646, -0.10229386, -0.1045378 ],\n",
       "       [-0.3064242 , -0.33019271, -0.26497138, -0.26756474, -0.19484074],\n",
       "       [-0.27799335, -0.07545626, -0.11345806, -0.09431823, -0.21325426],\n",
       "       [-0.04386499, -0.04839117, -0.27125716, -0.09745226, -0.32407326],\n",
       "       [-0.24422373, -0.40645817, -0.21593395, -0.06779587, -0.30401924],\n",
       "       [-0.15746664, -0.1956679 , -0.26226908, -0.3548264 , -0.15589371],\n",
       "       [-0.13890791, -0.1716529 , -0.22093779, -0.34755248, -0.12463316],\n",
       "       [-0.2039752 , -0.26802295, -0.13848478, -0.27653545, -0.08640008],\n",
       "       [-0.44597796, -0.3412125 , -0.17686535, -0.377585  , -0.28058758],\n",
       "       [-0.17364953,  0.06696859, -0.24239504, -0.06650545, -0.26569086],\n",
       "       [ 0.06296456, -0.2265068 , -0.11868587, -0.34187278, -0.10818244],\n",
       "       [-0.04626235, -0.18199065, -0.15491158, -0.25777775, -0.11847284],\n",
       "       [-0.18931006, -0.02034704, -0.14060014, -0.22072178, -0.12367889],\n",
       "       [-0.05694667, -0.32793742, -0.2061264 , -0.27146226, -0.13328046],\n",
       "       [-0.00554734, -0.3468829 , -0.08966803, -0.15996921, -0.23022306],\n",
       "       [-0.13473977, -0.01482856, -0.17873819, -0.18238032, -0.1127388 ],\n",
       "       [-0.28950503, -0.3731026 , -0.29872212, -0.1305946 , -0.20396802],\n",
       "       [-0.203229  , -0.17200004,  0.00537011,  0.10469806, -0.10237473],\n",
       "       [-0.09974219, -0.1434016 , -0.2516464 , -0.24831343, -0.23684493],\n",
       "       [-0.23738602, -0.41530964, -0.46807945, -0.23645078, -0.04194271],\n",
       "       [-0.3834561 , -0.19533414, -0.12514338, -0.23738137, -0.30426097],\n",
       "       [-0.01600974, -0.07638773, -0.33040327, -0.34026247, -0.17635688],\n",
       "       [-0.27849707, -0.18382308, -0.24327901, -0.33855474, -0.24688146],\n",
       "       [-0.21661979, -0.27729392, -0.17758451, -0.11881302, -0.36126733],\n",
       "       [-0.14888632, -0.22868963, -0.12785164, -0.2442649 , -0.2321683 ],\n",
       "       [-0.1270408 , -0.25086278, -0.31331402, -0.18555701, -0.23800188],\n",
       "       [-0.34132218, -0.18699981, -0.13999122, -0.3113305 , -0.24034359],\n",
       "       [-0.27888077, -0.22598161, -0.32545227, -0.05622336, -0.30414426],\n",
       "       [-0.10829107, -0.23460513, -0.2810108 , -0.29347903, -0.2916003 ]],\n",
       "      dtype=float32)"
      ]
     },
     "execution_count": 257,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "%%time\n",
    "\n",
    "model.predict(batch)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 258,
   "id": "8bf6f793-b65d-427d-bb06-c4a7ca770e85",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model: \"vgg16_ranking_model_1\"\n",
      "_________________________________________________________________\n",
      " Layer (type)                Output Shape              Param #   \n",
      "=================================================================\n",
      " vgg16 (Functional)          (None, 7, 7, 512)         14714688  \n",
      "                                                                 \n",
      " flatten_9 (Flatten)         multiple                  0         \n",
      "                                                                 \n",
      " sequential_18 (Sequential)  (None, 64)                13018048  \n",
      "                                                                 \n",
      " sequential_19 (Sequential)  (None, 1)                 65        \n",
      "                                                                 \n",
      " ranking_9 (Ranking)         multiple                  0 (unused)\n",
      "                                                                 \n",
      "=================================================================\n",
      "Total params: 27732801 (105.79 MB)\n",
      "Trainable params: 27732801 (105.79 MB)\n",
      "Non-trainable params: 0 (0.00 Byte)\n",
      "_________________________________________________________________\n",
      "CPU times: user 17.6 ms, sys: 219 μs, total: 17.8 ms\n",
      "Wall time: 15.7 ms\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "\n",
    "model.summary()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 259,
   "id": "adf7f33e-b4b3-4a0d-bd4a-7c056ba5e5d3",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1/10\n",
      "125/125 [==============================] - 502s 4s/step - ndcg_metric: 0.7767 - mrr_metric: 0.9519 - opa_metric: 0.6357 - loss: 4.5107 - regularization_loss: 0.0000e+00 - total_loss: 4.5107\n",
      "Epoch 2/10\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2024-06-14 02:40:05.674011: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 12108299114392806563\n",
      "2024-06-14 02:40:05.674088: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 12638169692605933059\n",
      "2024-06-14 02:40:05.674111: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 2450169764168108559\n",
      "2024-06-14 02:40:05.674128: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 13857713164677058033\n",
      "2024-06-14 02:40:05.674144: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 6032736105667095887\n",
      "2024-06-14 02:40:05.674160: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 10470155243942874895\n",
      "2024-06-14 02:40:05.674176: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 7040201953395781749\n",
      "2024-06-14 02:40:05.674192: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 15390714029690532465\n",
      "2024-06-14 02:40:05.674208: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 2071412137989468789\n",
      "2024-06-14 02:40:05.674225: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 2756944078021571663\n",
      "2024-06-14 02:40:05.674240: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 12423843672416064827\n",
      "2024-06-14 02:40:05.674258: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 16512142192914225443\n",
      "2024-06-14 02:40:05.674275: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 10958820116057423595\n",
      "2024-06-14 02:40:05.674292: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 5436320334591147883\n",
      "2024-06-14 02:40:05.674318: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 7346778733760922658\n",
      "2024-06-14 02:40:05.674335: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 1719463962008380450\n",
      "2024-06-14 02:40:05.674352: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 1261751261683017318\n",
      "2024-06-14 02:40:05.674369: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 11054645442464401340\n",
      "2024-06-14 02:40:05.674385: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 12081115441926497608\n",
      "2024-06-14 02:40:05.674401: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 10593943642323323310\n",
      "2024-06-14 02:40:05.674417: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 489834972884658942\n",
      "2024-06-14 02:40:05.674433: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 10412948733252217480\n",
      "2024-06-14 02:40:05.674449: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 342609418036931834\n",
      "2024-06-14 02:40:05.674464: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 12507834396003076598\n",
      "2024-06-14 02:40:05.674480: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 17736131097635582946\n",
      "2024-06-14 02:40:05.674496: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 7604090730620736746\n",
      "2024-06-14 02:40:05.674512: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 13098196097361409802\n",
      "2024-06-14 02:40:05.674528: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 13469931511122565144\n",
      "2024-06-14 02:40:05.674543: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 16366993509708691838\n",
      "2024-06-14 02:40:05.674559: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 9844649450192034500\n",
      "2024-06-14 02:40:05.674575: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 3374155953632304766\n",
      "2024-06-14 02:40:05.674642: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 6448954824214519480\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "125/125 [==============================] - 494s 4s/step - ndcg_metric: 0.8163 - mrr_metric: 0.9645 - opa_metric: 0.6992 - loss: 4.1707 - regularization_loss: 0.0000e+00 - total_loss: 4.1707\n",
      "Epoch 3/10\n",
      "125/125 [==============================] - 494s 4s/step - ndcg_metric: 0.8413 - mrr_metric: 0.9724 - opa_metric: 0.7273 - loss: 4.0015 - regularization_loss: 0.0000e+00 - total_loss: 4.0015\n",
      "Epoch 4/10\n",
      "125/125 [==============================] - 493s 4s/step - ndcg_metric: 0.8558 - mrr_metric: 0.9763 - opa_metric: 0.7422 - loss: 3.8867 - regularization_loss: 0.0000e+00 - total_loss: 3.8867\n",
      "Epoch 5/10\n",
      "125/125 [==============================] - 496s 4s/step - ndcg_metric: 0.8653 - mrr_metric: 0.9796 - opa_metric: 0.7535 - loss: 3.7903 - regularization_loss: 0.0000e+00 - total_loss: 3.7903\n",
      "Epoch 6/10\n",
      "125/125 [==============================] - 495s 4s/step - ndcg_metric: 0.8718 - mrr_metric: 0.9831 - opa_metric: 0.7634 - loss: 3.7033 - regularization_loss: 0.0000e+00 - total_loss: 3.7033\n",
      "Epoch 7/10\n",
      "125/125 [==============================] - 495s 4s/step - ndcg_metric: 0.8795 - mrr_metric: 0.9853 - opa_metric: 0.7734 - loss: 3.6197 - regularization_loss: 0.0000e+00 - total_loss: 3.6197\n",
      "Epoch 8/10\n",
      "125/125 [==============================] - 494s 4s/step - ndcg_metric: 0.8852 - mrr_metric: 0.9879 - opa_metric: 0.7816 - loss: 3.5385 - regularization_loss: 0.0000e+00 - total_loss: 3.5385\n",
      "Epoch 9/10\n",
      "125/125 [==============================] - 495s 4s/step - ndcg_metric: 0.8903 - mrr_metric: 0.9893 - opa_metric: 0.7903 - loss: 3.4564 - regularization_loss: 0.0000e+00 - total_loss: 3.4564\n",
      "Epoch 10/10\n",
      "125/125 [==============================] - 493s 4s/step - ndcg_metric: 0.8951 - mrr_metric: 0.9908 - opa_metric: 0.7981 - loss: 3.3731 - regularization_loss: 0.0000e+00 - total_loss: 3.3731\n",
      "CPU times: user 1h 22min 41s, sys: 2min 43s, total: 1h 25min 24s\n",
      "Wall time: 1h 22min 33s\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "\n",
    "history = model.fit(train_dataset, epochs=EPOCHS)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 260,
   "id": "0b643be0-cd7b-49e4-b446-aedcde3c5b46",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 7 μs, sys: 0 ns, total: 7 μs\n",
      "Wall time: 14.8 μs\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "\n",
    "show_on_test_batch = True"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 261,
   "id": "7ff2cccc-f5c7-470a-8626-cd73e519b537",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<_TakeDataset element_spec=(TensorSpec(shape=(None, 5, 224, 224, 3), dtype=tf.float32, name=None), TensorSpec(shape=(None, 5), dtype=tf.int32, name=None))>\n",
      "tf.Tensor([-0.2686391   2.4140956  -1.9536101   1.4114383  -0.27827218], shape=(5,), dtype=float32) tf.Tensor([2 4 0 3 1], shape=(5,), dtype=int32) tf.Tensor([2 4 0 3 1], shape=(5,), dtype=int32)\n",
      "tf.Tensor([ 1.7282531  -0.88702345  0.10940026  0.68998194  1.6007546 ], shape=(5,), dtype=float32) tf.Tensor([3 0 4 1 2], shape=(5,), dtype=int32) tf.Tensor([4 0 1 2 3], shape=(5,), dtype=int32)\n",
      "tf.Tensor([ 1.4030336   1.009131   -2.1816912   0.97378343  0.11508456], shape=(5,), dtype=float32) tf.Tensor([3 2 1 4 0], shape=(5,), dtype=int32) tf.Tensor([4 3 0 2 1], shape=(5,), dtype=int32)\n",
      "CPU times: user 5.02 s, sys: 171 ms, total: 5.19 s\n",
      "Wall time: 4.85 s\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "\n",
    "if show_on_test_batch:\n",
    "    batch = train_dataset.take(1)\n",
    "    print(batch)\n",
    "    for images, labels in batch:\n",
    "        print(model(images)[0], labels[0], predict_rank(model, images)[0])\n",
    "        print(model(images)[1], labels[1], predict_rank(model, images)[1])\n",
    "        print(model(images)[2], labels[2], predict_rank(model, images)[2])\n",
    "        break"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 51,
   "id": "229dea06-8c9a-4988-aaee-85314b4d36ae",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/1 [==============================] - 1s 941ms/step - ndcg_metric: 0.9114 - mrr_metric: 1.0000 - opa_metric: 0.8300 - loss: 14.3230 - regularization_loss: 0.0000e+00 - total_loss: 14.3230\n",
      "Test metrics: {'ndcg_metric': 0.9114178419113159, 'mrr_metric': 1.0, 'opa_metric': 0.8299999833106995, 'loss': 14.323023796081543, 'regularization_loss': 0, 'total_loss': 14.323023796081543}\n"
     ]
    }
   ],
   "source": [
    "test_metrics = model.evaluate(test_images, test_labels, return_dict=True)\n",
    "print(f\"Test metrics: {test_metrics}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 52,
   "id": "d27860bc-559f-48dc-b0ae-d72f2b4e8b58",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/1 [==============================] - 0s 456ms/step - ndcg_metric: 0.9172 - mrr_metric: 1.0000 - opa_metric: 0.8420 - loss: 13.3479 - regularization_loss: 0.0000e+00 - total_loss: 13.3479\n",
      "Val metrics: {'ndcg_metric': 0.9114178419113159, 'mrr_metric': 1.0, 'opa_metric': 0.8299999833106995, 'loss': 14.323023796081543, 'regularization_loss': 0, 'total_loss': 14.323023796081543}\n"
     ]
    }
   ],
   "source": [
    "val_metrics = model.evaluate(val_images, val_labels, return_dict=True)\n",
    "print(f\"Val metrics: {test_metrics}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 262,
   "id": "f0ac01ea-19f6-48a4-98c1-910b1ca4ae36",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(25088, 512), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c8111fbb0>, 139765293182736), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(25088, 512), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c8111fbb0>, 139765293182736), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(512,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1da7e8dd80>, 139765293184816), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(512,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1da7e8dd80>, 139765293184816), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(512, 256), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c9c1a2140>, 139765293180576), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(512, 256), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c9c1a2140>, 139765293180576), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(256,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c9c1a3d30>, 139765293184416), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(256,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c9c1a3d30>, 139765293184416), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(256, 128), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1cf6a37640>, 139765293186016), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(256, 128), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1cf6a37640>, 139765293186016), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(128,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c97f6b460>, 139765293181056), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(128,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c97f6b460>, 139765293181056), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(128, 64), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c97f68d00>, 139765293177056), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(128, 64), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c97f68d00>, 139765293177056), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(64,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1ca06ce260>, 139765293182896), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(64,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1ca06ce260>, 139765293182896), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(64, 1), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c97bd1de0>, 139765291772832), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(64, 1), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c97bd1de0>, 139765291772832), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(1,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1ca03e37c0>, 139765293178576), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(1,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1ca03e37c0>, 139765293178576), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(25088, 512), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c8111fbb0>, 139765293182736), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(25088, 512), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c8111fbb0>, 139765293182736), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(512,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1da7e8dd80>, 139765293184816), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(512,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1da7e8dd80>, 139765293184816), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(512, 256), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c9c1a2140>, 139765293180576), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(512, 256), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c9c1a2140>, 139765293180576), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(256,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c9c1a3d30>, 139765293184416), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(256,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c9c1a3d30>, 139765293184416), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(256, 128), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1cf6a37640>, 139765293186016), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(256, 128), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1cf6a37640>, 139765293186016), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(128,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c97f6b460>, 139765293181056), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(128,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c97f6b460>, 139765293181056), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(128, 64), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c97f68d00>, 139765293177056), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(128, 64), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c97f68d00>, 139765293177056), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(64,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1ca06ce260>, 139765293182896), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(64,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1ca06ce260>, 139765293182896), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(64, 1), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c97bd1de0>, 139765291772832), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(64, 1), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c97bd1de0>, 139765291772832), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(1,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1ca03e37c0>, 139765293178576), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(1,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1ca03e37c0>, 139765293178576), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Assets written to: tesla_saved_models/VGG16RankingModel_20240614_035421_unfreezed_1e-06/assets\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Assets written to: tesla_saved_models/VGG16RankingModel_20240614_035421_unfreezed_1e-06/assets\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model saved in tesla_saved_models/VGG16RankingModel_20240614_035421_unfreezed_1e-06 as VGG16RankingModel_20240614_035421_unfreezed_1e-06\n",
      "CPU times: user 7.03 s, sys: 521 ms, total: 7.55 s\n",
      "Wall time: 7.49 s\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "\n",
    "if SAVE_MODEL:\n",
    "    save_model_with_timestamp_and_lr(model, lr)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9c2ee424-e691-4700-b739-4382921da76f",
   "metadata": {},
   "source": [
    "### Evaluating"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d630e2cb-d180-400f-8b06-e8de3ea0da30",
   "metadata": {},
   "outputs": [],
   "source": [
    "loaded_model = tf.keras.models.load_model('saved_models/VGG16RankingModel_20240611_214054', \n",
    "                                          custom_objects={'ListMLELoss': tfr.keras.losses.ListMLELoss()})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "326a8fce-9637-4640-9fc4-f8e7324f9fe4",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "%%time\n",
    "\n",
    "metrics = compute_dataset_metrics(loaded_model, train_dataset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c19a5b36-bc72-4e88-b202-9cf81ed35ebb",
   "metadata": {},
   "outputs": [],
   "source": [
    "%%time\n",
    "\n",
    "print(metrics)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a8b6c045-d091-4e59-927e-a27e881755ec",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "%%time\n",
    "\n",
    "metrics = compute_dataset_metrics(loaded_model, val_dataset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2ad900cf-cdab-43c0-9b39-7dd05055e09c",
   "metadata": {},
   "outputs": [],
   "source": [
    "%%time\n",
    "\n",
    "print(metrics)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "faf1ed12-5c94-4c48-92ca-60ae9f0fb123",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "%%time\n",
    "\n",
    "metrics = compute_dataset_metrics(loaded_model, test_dataset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fece7f20-3c4f-49fb-b6b7-a2c58e5e9094",
   "metadata": {},
   "outputs": [],
   "source": [
    "%%time\n",
    "\n",
    "print(metrics)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b451bb2b-0ccc-4630-8cf9-73d1c7c9aa11",
   "metadata": {},
   "outputs": [],
   "source": [
    "%%time\n",
    "\n",
    "show_on_test_batch = True"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0c0ed2e9-cc16-43ee-8b91-9f3a25174884",
   "metadata": {},
   "outputs": [],
   "source": [
    "%%time\n",
    "\n",
    "if show_on_test_batch:\n",
    "    batch = test_dataset.take(1)\n",
    "    print(batch)\n",
    "    for images, labels in batch:\n",
    "        print(model(images)[0], labels[0], predict_rank(model, images)[0])\n",
    "        print(model(images)[1], labels[1], predict_rank(model, images)[1])\n",
    "        print(model(images)[2], labels[2], predict_rank(model, images)[2])\n",
    "        break"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "530d7e95-0660-40aa-a89b-d5e1b5aee5d3",
   "metadata": {},
   "outputs": [],
   "source": [
    "%%time\n",
    "\n",
    "plot_metrics(metrics)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "26fe757e-5bda-451a-a20b-04e06c8689de",
   "metadata": {},
   "source": [
    "## VGG19"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "554a5919-22e5-4295-8a98-65dde919d557",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true
   },
   "source": [
    "### Training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 279,
   "id": "3e3aa771-b3ff-4731-a87c-6502d27b3be7",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 338 ms, sys: 87.8 ms, total: 426 ms\n",
      "Wall time: 416 ms\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "\n",
    "loss = tfr.keras.losses.ListMLELoss()\n",
    "model = VGG19RankingModel(loss, trainable=True)\n",
    "lr = 1e-6\n",
    "optimizer = tf.keras.optimizers.Adam(learning_rate=lr, clipnorm=1.0)\n",
    "model.compile(optimizer=optimizer)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 280,
   "id": "9afd607a-73d5-4b84-b667-2d4aa60f4ee5",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 3.95 ms, sys: 159 μs, total: 4.11 ms\n",
      "Wall time: 1.85 ms\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "\n",
    "batch = train_dataset.take(1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 281,
   "id": "15bad203-3718-4863-8e86-e448c0bd9e5c",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/1 [==============================] - 7s 7s/step\n",
      "CPU times: user 7.28 s, sys: 273 ms, total: 7.55 s\n",
      "Wall time: 7.28 s\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2024-06-14 09:25:17.535584: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 14764271650021614035\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "array([[ 0.0780599 , -0.21177793, -0.2105574 ,  0.14103517,  0.16088627],\n",
       "       [ 0.12396344,  0.07609893, -0.07047979,  0.16628256, -0.07935337],\n",
       "       [ 0.05026826,  0.45136118,  0.13552755,  0.02198543,  0.14668936],\n",
       "       [ 0.20898205,  0.09741619,  0.31735513,  0.26989424,  0.10721479],\n",
       "       [ 0.04028548, -0.01898882, -0.06260123, -0.2277972 ,  0.06247074],\n",
       "       [ 0.01537807,  0.08909898,  0.2344847 ,  0.10162647, -0.04495493],\n",
       "       [-0.05183747,  0.13719876,  0.11634661, -0.1182182 ,  0.1698629 ],\n",
       "       [ 0.10974462, -0.14096045, -0.08962672, -0.14553975, -0.12000169],\n",
       "       [ 0.0268673 ,  0.03885245, -0.14400718, -0.01615727,  0.06438422],\n",
       "       [ 0.16807479, -0.24134254,  0.04373427, -0.06008611, -0.06539148],\n",
       "       [ 0.18011591,  0.1335904 ,  0.15762052,  0.20916548,  0.08857246],\n",
       "       [-0.0940796 ,  0.31741142, -0.1069942 ,  0.00703353,  0.00714217],\n",
       "       [ 0.19804406,  0.20109896,  0.12005404,  0.17895667,  0.172095  ],\n",
       "       [ 0.03693856,  0.10492587,  0.06794888,  0.2717196 , -0.07400772],\n",
       "       [ 0.0487109 ,  0.2783461 ,  0.22539914,  0.33386093,  0.03008522],\n",
       "       [ 0.2990093 ,  0.02521366,  0.21723405,  0.12593964,  0.22334617],\n",
       "       [ 0.01110077, -0.08425024,  0.00404032, -0.15570767,  0.04160397],\n",
       "       [-0.00999965,  0.05339696, -0.03759487,  0.324629  , -0.03289387],\n",
       "       [ 0.1378361 , -0.03840393, -0.04196721, -0.05663704,  0.1375565 ],\n",
       "       [ 0.1616465 ,  0.03768646,  0.19493218, -0.01841702,  0.04468495],\n",
       "       [ 0.06948616,  0.02583532,  0.24923937,  0.13916178, -0.06695803],\n",
       "       [ 0.2393565 ,  0.14564946,  0.0581629 ,  0.46999735,  0.0586637 ],\n",
       "       [ 0.21466708, -0.09669949,  0.35366014,  0.10720767,  0.27914184],\n",
       "       [ 0.0691562 ,  0.13957962,  0.0345297 ,  0.05530173,  0.13512412],\n",
       "       [ 0.02200617,  0.10041535, -0.13421845,  0.26491868, -0.07395262],\n",
       "       [-0.26945233,  0.12869853,  0.02601011,  0.10342797,  0.04254843],\n",
       "       [-0.2211275 , -0.09292691, -0.25055173,  0.10814391,  0.08190061],\n",
       "       [-0.17398067,  0.13628411,  0.20557424,  0.13689703,  0.09782888],\n",
       "       [ 0.33154815, -0.01491924,  0.00197649, -0.02418731, -0.08667425],\n",
       "       [ 0.19468258,  0.20962079,  0.2427817 ,  0.21250895,  0.24477325]],\n",
       "      dtype=float32)"
      ]
     },
     "execution_count": 281,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "%%time\n",
    "\n",
    "model.predict(batch)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 282,
   "id": "61dabfac-2ec9-46cf-96c5-efd61bee7dd2",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model: \"vgg19_ranking_model_2\"\n",
      "_________________________________________________________________\n",
      " Layer (type)                Output Shape              Param #   \n",
      "=================================================================\n",
      " vgg19 (Functional)          (None, 7, 7, 512)         20024384  \n",
      "                                                                 \n",
      " flatten_12 (Flatten)        multiple                  0         \n",
      "                                                                 \n",
      " sequential_24 (Sequential)  (None, 64)                13018048  \n",
      "                                                                 \n",
      " sequential_25 (Sequential)  (None, 1)                 65        \n",
      "                                                                 \n",
      " ranking_12 (Ranking)        multiple                  0 (unused)\n",
      "                                                                 \n",
      "=================================================================\n",
      "Total params: 33042497 (126.05 MB)\n",
      "Trainable params: 33042497 (126.05 MB)\n",
      "Non-trainable params: 0 (0.00 Byte)\n",
      "_________________________________________________________________\n",
      "CPU times: user 27.5 ms, sys: 3.99 ms, total: 31.5 ms\n",
      "Wall time: 28.6 ms\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "\n",
    "model.summary()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 283,
   "id": "f2f2d93b-8db2-4f5e-8953-76d4b1f153ca",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1/10\n",
      "125/125 [==============================] - 514s 4s/step - ndcg_metric: 0.7757 - mrr_metric: 0.9500 - opa_metric: 0.6364 - loss: 4.4392 - regularization_loss: 0.0000e+00 - total_loss: 4.4392\n",
      "Epoch 2/10\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2024-06-14 09:33:51.066181: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 1916183192265826380\n",
      "2024-06-14 09:33:51.066225: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 14310136826461651500\n",
      "2024-06-14 09:33:51.066241: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 12938603152751296629\n",
      "2024-06-14 09:33:51.066249: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 14764271650021614035\n",
      "2024-06-14 09:33:51.066258: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 510920707912122881\n",
      "2024-06-14 09:33:51.066266: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 2248832325523727073\n",
      "2024-06-14 09:33:51.066274: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 9600531165392667815\n",
      "2024-06-14 09:33:51.066281: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 6329732935015875635\n",
      "2024-06-14 09:33:51.066290: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 6823935647643298421\n",
      "2024-06-14 09:33:51.066297: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 7040655693434899309\n",
      "2024-06-14 09:33:51.066306: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 5551121410355842913\n",
      "2024-06-14 09:33:51.066314: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 448322386870341637\n",
      "2024-06-14 09:33:51.066321: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 14476043979715703147\n",
      "2024-06-14 09:33:51.066328: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 7249414365456429655\n",
      "2024-06-14 09:33:51.066337: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 188325259605078259\n",
      "2024-06-14 09:33:51.066345: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 2467532506570675223\n",
      "2024-06-14 09:33:51.066353: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 2131939493448365687\n",
      "2024-06-14 09:33:51.066360: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 11786493475988608909\n",
      "2024-06-14 09:33:51.066368: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 16048947158737582135\n",
      "2024-06-14 09:33:51.066377: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 3617827592617327001\n",
      "2024-06-14 09:33:51.066385: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 15215156478195366707\n",
      "2024-06-14 09:33:51.066393: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 7887546812061628443\n",
      "2024-06-14 09:33:51.066400: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 1793107625411944115\n",
      "2024-06-14 09:33:51.066411: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 14584362184060071782\n",
      "2024-06-14 09:33:51.066420: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 15055823035486638252\n",
      "2024-06-14 09:33:51.066427: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 2802802177487401458\n",
      "2024-06-14 09:33:51.066434: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 6481909406217040540\n",
      "2024-06-14 09:33:51.066441: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 1941998129770598366\n",
      "2024-06-14 09:33:51.066449: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 16812569349086780252\n",
      "2024-06-14 09:33:51.066458: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 8755670508418342100\n",
      "2024-06-14 09:33:51.066465: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 13005067920986315902\n",
      "2024-06-14 09:33:51.066519: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 11822025476306083302\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "125/125 [==============================] - 507s 4s/step - ndcg_metric: 0.8260 - mrr_metric: 0.9683 - opa_metric: 0.7082 - loss: 4.1201 - regularization_loss: 0.0000e+00 - total_loss: 4.1201\n",
      "Epoch 3/10\n",
      "125/125 [==============================] - 506s 4s/step - ndcg_metric: 0.8434 - mrr_metric: 0.9727 - opa_metric: 0.7297 - loss: 3.9694 - regularization_loss: 0.0000e+00 - total_loss: 3.9694\n",
      "Epoch 4/10\n",
      "125/125 [==============================] - 507s 4s/step - ndcg_metric: 0.8557 - mrr_metric: 0.9755 - opa_metric: 0.7445 - loss: 3.8560 - regularization_loss: 0.0000e+00 - total_loss: 3.8560\n",
      "Epoch 5/10\n",
      "125/125 [==============================] - 508s 4s/step - ndcg_metric: 0.8654 - mrr_metric: 0.9785 - opa_metric: 0.7567 - loss: 3.7576 - regularization_loss: 0.0000e+00 - total_loss: 3.7576\n",
      "Epoch 6/10\n",
      "125/125 [==============================] - 507s 4s/step - ndcg_metric: 0.8739 - mrr_metric: 0.9821 - opa_metric: 0.7678 - loss: 3.6660 - regularization_loss: 0.0000e+00 - total_loss: 3.6660\n",
      "Epoch 7/10\n",
      "125/125 [==============================] - 508s 4s/step - ndcg_metric: 0.8805 - mrr_metric: 0.9843 - opa_metric: 0.7782 - loss: 3.5755 - regularization_loss: 0.0000e+00 - total_loss: 3.5755\n",
      "Epoch 8/10\n",
      "125/125 [==============================] - 508s 4s/step - ndcg_metric: 0.8879 - mrr_metric: 0.9877 - opa_metric: 0.7895 - loss: 3.4836 - regularization_loss: 0.0000e+00 - total_loss: 3.4836\n",
      "Epoch 9/10\n",
      "125/125 [==============================] - 508s 4s/step - ndcg_metric: 0.8937 - mrr_metric: 0.9901 - opa_metric: 0.7980 - loss: 3.3894 - regularization_loss: 0.0000e+00 - total_loss: 3.3894\n",
      "Epoch 10/10\n",
      "125/125 [==============================] - 507s 4s/step - ndcg_metric: 0.8991 - mrr_metric: 0.9912 - opa_metric: 0.8066 - loss: 3.2922 - regularization_loss: 0.0000e+00 - total_loss: 3.2922\n",
      "CPU times: user 1h 24min 53s, sys: 2min 49s, total: 1h 27min 42s\n",
      "Wall time: 1h 24min 40s\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "\n",
    "history = model.fit(train_dataset, epochs=EPOCHS)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 284,
   "id": "851c203a-46c5-437b-b275-79ca40e61714",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 0 ns, sys: 7 μs, total: 7 μs\n",
      "Wall time: 12.9 μs\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "\n",
    "show_on_test_batch = True"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 285,
   "id": "1d11f834-2ab5-4b5b-937a-3d0589e77f22",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<_TakeDataset element_spec=(TensorSpec(shape=(None, 5, 224, 224, 3), dtype=tf.float32, name=None), TensorSpec(shape=(None, 5), dtype=tf.int32, name=None))>\n",
      "tf.Tensor([ 0.7904267   3.2590365  -2.106199    0.09893734  1.6632942 ], shape=(5,), dtype=float32) tf.Tensor([2 4 0 1 3], shape=(5,), dtype=int32) tf.Tensor([2 4 0 1 3], shape=(5,), dtype=int32)\n",
      "tf.Tensor([-0.299446   1.2902373  2.0067987  1.1649668  1.1827627], shape=(5,), dtype=float32) tf.Tensor([0 2 3 4 1], shape=(5,), dtype=int32) tf.Tensor([0 3 4 1 2], shape=(5,), dtype=int32)\n",
      "tf.Tensor([-1.2357972   1.2359915   0.16709673  1.5604701   1.885488  ], shape=(5,), dtype=float32) tf.Tensor([1 4 0 2 3], shape=(5,), dtype=int32) tf.Tensor([0 2 1 3 4], shape=(5,), dtype=int32)\n",
      "CPU times: user 5.61 s, sys: 214 ms, total: 5.82 s\n",
      "Wall time: 5.06 s\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "\n",
    "if show_on_test_batch:\n",
    "    batch = train_dataset.take(1)\n",
    "    print(batch)\n",
    "    for images, labels in batch:\n",
    "        print(model(images)[0], labels[0], predict_rank(model, images)[0])\n",
    "        print(model(images)[1], labels[1], predict_rank(model, images)[1])\n",
    "        print(model(images)[2], labels[2], predict_rank(model, images)[2])\n",
    "        break"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "id": "9e4c7344-b923-457f-aa8c-f55ebc647e15",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/1 [==============================] - 1s 967ms/step - ndcg_metric: 0.9023 - mrr_metric: 1.0000 - opa_metric: 0.8020 - loss: 11.8531 - regularization_loss: 0.0000e+00 - total_loss: 11.8531\n",
      "Test metrics: {'ndcg_metric': 0.9022544622421265, 'mrr_metric': 1.0, 'opa_metric': 0.8019999861717224, 'loss': 11.853145599365234, 'regularization_loss': 0, 'total_loss': 11.853145599365234}\n"
     ]
    }
   ],
   "source": [
    "test_metrics = model.evaluate(test_images, test_labels, return_dict=True)\n",
    "print(f\"Test metrics: {test_metrics}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "id": "0df068e4-6b08-4a58-8224-c9cb60204f70",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/1 [==============================] - 1s 530ms/step - ndcg_metric: 0.9285 - mrr_metric: 1.0000 - opa_metric: 0.8500 - loss: 10.7363 - regularization_loss: 0.0000e+00 - total_loss: 10.7363\n",
      "Val metrics: {'ndcg_metric': 0.9022544622421265, 'mrr_metric': 1.0, 'opa_metric': 0.8019999861717224, 'loss': 11.853145599365234, 'regularization_loss': 0, 'total_loss': 11.853145599365234}\n"
     ]
    }
   ],
   "source": [
    "val_metrics = model.evaluate(val_images, val_labels, return_dict=True)\n",
    "print(f\"Val metrics: {test_metrics}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 286,
   "id": "63e17a83-15c7-4e70-b8e5-a494714a6c36",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(25088, 512), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c60a428f0>, 139762284235232), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(25088, 512), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c60a428f0>, 139762284235232), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(512,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1ca3e81c90>, 139762284244192), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(512,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1ca3e81c90>, 139762284244192), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(512, 256), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1ca3e822f0>, 139760860640032), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(512, 256), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1ca3e822f0>, 139760860640032), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(256,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c63fb9f30>, 139760860639952), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(256,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c63fb9f30>, 139760860639952), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(256, 128), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c63fba1a0>, 139762284243712), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(256, 128), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c63fba1a0>, 139762284243712), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(128,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1ca307fee0>, 139762284233472), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(128,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1ca307fee0>, 139762284233472), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(128, 64), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c8b284ca0>, 139760860634512), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(128, 64), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c8b284ca0>, 139760860634512), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(64,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c60a87340>, 139760860628752), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(64,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c60a87340>, 139760860628752), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(64, 1), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c81d2ff10>, 139762300337440), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(64, 1), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c81d2ff10>, 139762300337440), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(1,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c63d09840>, 139762300326160), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(1,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c63d09840>, 139762300326160), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(25088, 512), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c60a428f0>, 139762284235232), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(25088, 512), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c60a428f0>, 139762284235232), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(512,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1ca3e81c90>, 139762284244192), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(512,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1ca3e81c90>, 139762284244192), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(512, 256), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1ca3e822f0>, 139760860640032), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(512, 256), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1ca3e822f0>, 139760860640032), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(256,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c63fb9f30>, 139760860639952), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(256,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c63fb9f30>, 139760860639952), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(256, 128), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c63fba1a0>, 139762284243712), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(256, 128), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c63fba1a0>, 139762284243712), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(128,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1ca307fee0>, 139762284233472), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(128,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1ca307fee0>, 139762284233472), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(128, 64), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c8b284ca0>, 139760860634512), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(128, 64), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c8b284ca0>, 139760860634512), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(64,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c60a87340>, 139760860628752), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(64,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c60a87340>, 139760860628752), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(64, 1), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c81d2ff10>, 139762300337440), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(64, 1), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c81d2ff10>, 139762300337440), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(1,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c63d09840>, 139762300326160), {}).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(1,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f1c63d09840>, 139762300326160), {}).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Assets written to: tesla_saved_models/VGG19RankingModel_20240614_105002_unfreezed_1e-06/assets\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Assets written to: tesla_saved_models/VGG19RankingModel_20240614_105002_unfreezed_1e-06/assets\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model saved in tesla_saved_models/VGG19RankingModel_20240614_105002_unfreezed_1e-06 as VGG19RankingModel_20240614_105002_unfreezed_1e-06\n",
      "CPU times: user 2.67 s, sys: 435 ms, total: 3.1 s\n",
      "Wall time: 3.05 s\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "\n",
    "if SAVE_MODEL:\n",
    "    save_model_with_timestamp_and_lr(model, lr)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b9d7e380-7e79-4bc0-8bf4-fff60275cba7",
   "metadata": {},
   "source": [
    "### Evaluating"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "516cd1ae-544c-4e41-8282-e45e0244f266",
   "metadata": {},
   "outputs": [],
   "source": [
    "loaded_model = tf.keras.models.load_model('saved_models/VGG19RankingModel_20240611_205427', \n",
    "                                          custom_objects={'ListMLELoss': tfr.keras.losses.ListMLELoss()})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "38e8727d-16f2-4028-a67c-647dd98b2505",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "%%time\n",
    "\n",
    "metrics = compute_dataset_metrics(loaded_model, train_dataset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2ce89ddb-b057-49ed-8046-90f2197c603c",
   "metadata": {},
   "outputs": [],
   "source": [
    "%%time\n",
    "\n",
    "print(metrics)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fa1869cf-8bdd-4c9b-b2fd-cdd9924f9a16",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "%%time\n",
    "\n",
    "metrics = compute_dataset_metrics(loaded_model, val_dataset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c7ee8741-faa6-466b-932d-7499e829a49a",
   "metadata": {},
   "outputs": [],
   "source": [
    "%%time\n",
    "\n",
    "print(metrics)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f12c0d7d-f8e7-481f-b868-9444e57265a9",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "%%time\n",
    "\n",
    "metrics = compute_dataset_metrics(loaded_model, test_dataset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2056c1af-2538-423b-9d8d-799cc6d32d93",
   "metadata": {},
   "outputs": [],
   "source": [
    "%%time\n",
    "\n",
    "print(metrics)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "afe7dd64-73f6-41a5-821f-0ebd5e83bb7f",
   "metadata": {},
   "outputs": [],
   "source": [
    "%%time\n",
    "\n",
    "show_on_test_batch = True"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f3ea8de9-76cc-49eb-91e9-cf6e08a68fea",
   "metadata": {},
   "outputs": [],
   "source": [
    "%%time\n",
    "\n",
    "if show_on_test_batch:\n",
    "    batch = test_dataset.take(1)\n",
    "    print(batch)\n",
    "    for images, labels in batch:\n",
    "        print(model(images)[0], labels[0], predict_rank(model, images)[0])\n",
    "        print(model(images)[1], labels[1], predict_rank(model, images)[1])\n",
    "        print(model(images)[2], labels[2], predict_rank(model, images)[2])\n",
    "        break"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a3bc413a-8fc8-4428-92d0-f9611f74e668",
   "metadata": {},
   "outputs": [],
   "source": [
    "%%time\n",
    "\n",
    "plot_metrics(metrics)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "532e8181-e0a5-41c5-aa75-b06fe930d1e4",
   "metadata": {},
   "source": [
    "# Benchmarking"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "27308b2b-2aae-4398-a6d8-f35e07b50e15",
   "metadata": {},
   "outputs": [],
   "source": [
    "paths = [\n",
    "    \"saved_models/MobileNetV2RankingModel_20240611_192117\",\n",
    "    \"saved_models/VGG16RankingModel_20240611_214054\",\n",
    "    \"saved_models/VGG19RankingModel_20240611_205427\",\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "a80d591d-de6f-4524-a053-88012105f440",
   "metadata": {},
   "outputs": [],
   "source": [
    "import time\n",
    "\n",
    "def test_inference_speed(model_path, batch_sizes):\n",
    "    model = tf.keras.models.load_model(model_path, custom_objects={'ListMLELoss': tfr.keras.losses.ListMLELoss()})\n",
    "    num_benchmark_iterations = 100\n",
    "    \n",
    "    for batch_size in batch_sizes:\n",
    "        timings = []\n",
    "        \n",
    "        for _ in range(num_benchmark_iterations):\n",
    "            #bs = np.random.randint(10, 100)\n",
    "            data = np.random.random((batch_size, 5, 224, 224, 3))\n",
    "            start_time = time.time()\n",
    "            predictions = model(data, training=False)\n",
    "            end_time = time.time()\n",
    "            #sizes.append(bs)\n",
    "            timings.append(end_time - start_time)\n",
    "        \n",
    "        avg_time = np.mean(timings)\n",
    "        std_dev = np.std(timings)\n",
    "        print(f'Batch size is {batch_size}')\n",
    "        print(f'Average inference time per batch: {avg_time:.4f} seconds')\n",
    "        print(f'Standard deviation of inference times: {std_dev:.4f} seconds')\n",
    "        print()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "4e2e95ae-419c-41f7-bdb2-7d69f027f6b4",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model is MobileNetV2RankingModel\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2024-06-14 23:27:22.594390: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1929] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 31137 MB memory:  -> device: 0, name: Tesla V100-SXM2-32GB, pci bus id: 0000:08:00.0, compute capability: 7.0\n",
      "2024-06-14 23:27:29.498344: I external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:454] Loaded cuDNN version 8904\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Batch size is 25\n",
      "Average inference time per batch: 0.1714 seconds\n",
      "Standard deviation of inference times: 0.2396 seconds\n",
      "\n",
      "Batch size is 50\n",
      "Average inference time per batch: 0.3085 seconds\n",
      "Standard deviation of inference times: 0.2897 seconds\n",
      "\n",
      "Batch size is 75\n",
      "Average inference time per batch: 0.4502 seconds\n",
      "Standard deviation of inference times: 0.3786 seconds\n",
      "\n",
      "Batch size is 100\n",
      "Average inference time per batch: 0.5911 seconds\n",
      "Standard deviation of inference times: 0.4663 seconds\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2024-06-14 23:32:25.768625: W tensorflow/core/kernels/gpu_utils.cc:54] Failed to allocate memory for convolution redzone checking; skipping this check. This is benign and only means that we won't check cudnn for out-of-bounds reads and writes. This message will only be printed once.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Batch size is 200\n",
      "Average inference time per batch: 1.1490 seconds\n",
      "Standard deviation of inference times: 0.7959 seconds\n",
      "\n",
      "Model is VGG16RankingModel\n",
      "Batch size is 25\n",
      "Average inference time per batch: 0.2757 seconds\n",
      "Standard deviation of inference times: 1.3520 seconds\n",
      "\n",
      "Batch size is 50\n",
      "Average inference time per batch: 0.5111 seconds\n",
      "Standard deviation of inference times: 2.3891 seconds\n",
      "\n",
      "Batch size is 75\n",
      "Average inference time per batch: 0.7357 seconds\n",
      "Standard deviation of inference times: 3.2799 seconds\n",
      "\n",
      "Batch size is 100\n",
      "Average inference time per batch: 0.9875 seconds\n",
      "Standard deviation of inference times: 4.2949 seconds\n",
      "\n",
      "Batch size is 200\n",
      "Average inference time per batch: 1.8096 seconds\n",
      "Standard deviation of inference times: 7.3021 seconds\n",
      "\n",
      "Model is VGG19RankingModel\n",
      "Batch size is 25\n",
      "Average inference time per batch: 0.1413 seconds\n",
      "Standard deviation of inference times: 0.0103 seconds\n",
      "\n",
      "Batch size is 50\n",
      "Average inference time per batch: 0.2754 seconds\n",
      "Standard deviation of inference times: 0.0167 seconds\n",
      "\n",
      "Batch size is 75\n",
      "Average inference time per batch: 0.4081 seconds\n",
      "Standard deviation of inference times: 0.0162 seconds\n",
      "\n",
      "Batch size is 100\n",
      "Average inference time per batch: 0.5427 seconds\n",
      "Standard deviation of inference times: 0.0200 seconds\n",
      "\n",
      "Batch size is 200\n",
      "Average inference time per batch: 1.0763 seconds\n",
      "Standard deviation of inference times: 0.0287 seconds\n",
      "\n"
     ]
    }
   ],
   "source": [
    "for path in paths:\n",
    "    print(\"Model is\", path.split('/')[1].split('_')[0])\n",
    "    test_inference_speed(path, [25, 50, 75, 100, 200])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "62f9e60a-2def-4534-a047-a2a81a5c5f15",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6bb1cf18-9fe4-4f64-b101-995df802617f",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (diploma_env second)",
   "language": "python",
   "name": "diploma_second"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
